# 2. Fit SimCLR

In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
# Seed PyTorch's default random number generator with a fixed value.
torch.manual_seed(0)

# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# each cell runs, so edits to the local `src` package are picked up without
# restarting the kernel.
%load_ext autoreload 
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.model import ContrastiveEmbedding
from src.datamodule import ContrastiveDataModule
from src.tensorboard import start_tensorboard
from pytorch_lightning import Trainer

In [3]:
# --- Stage switches ---------------------------------------------------------
# Each flag gates one of the `if TRAIN:` / `if EVAL:` cells further down.
EVAL = True
TRAIN = False

# --- Dataset selection and file locations -----------------------------------
TISSUE = 'kidney'
DATA_DIR = './data/images/' + TISSUE
# NOTE(review): '$TMPDIR' is stored literally here; Python does not expand it.
# Presumably the data module (or a shell it invokes) resolves the variable --
# confirm in src.datamodule.
CACHE_DIR = '$TMPDIR/images/' + TISSUE
CHECKPOINT = './data/' + TISSUE + '.ckpt'

# --- Data loading (forwarded to ContrastiveDataModule) -----------------------
IMAGE_SIZE = 512
PATCH_SIZE = 256      # also forwarded to ContrastiveEmbedding
BATCH_SIZE = 150
NUM_WORKERS = 16

# --- Model and optimisation (forwarded to ContrastiveEmbedding / Trainer) ----
ENCODER = 'densenet121'
EMBEDDING_DIM = 128
TEMPERATURE = 1.0
LEARNING_RATE = 0.0005
NUM_EPOCHS = 1000     # used for both min_epochs and max_epochs of the Trainer

Prepare the subset of training images for which specific antibody binding is plausible.

In [4]:
# Load the HPA metadata table (the first CSV column becomes the index, which is
# what the image-level duplicate check and `train_images` below are built on)
# and keep only the rows for the tissue under study.
df = pd.read_csv('./data/hpa_v21_kidney.csv', index_col=0)
df = df.loc[df['Tissue'] == TISSUE]

# Some images are associated with multiple genes (i.e. nonspecific antibody); we remove these.
# `Index.duplicated(keep=False)` marks every row whose index label occurs more
# than once. `.assign` builds a new frame instead of writing a column into the
# filtered frame, which avoids pandas' chained-assignment warning.
df = df.assign(duplicated=df.index.duplicated(keep=False))

# Only include high-quality images: high or medium staining, "Enhanced"
# reliability, and an index label that is unique within this tissue.
is_well_stained = df['Staining'].isin(['high', 'medium'])
is_reliable = df['Reliability'] == 'Enhanced'
train_df = df.loc[is_well_stained & is_reliable & ~df['duplicated']]

# Index labels of the selected rows; the training cell tests membership in this set.
train_images = set(train_df.index)

print(len(train_images), 'images')
print(len(train_df['Gene'].unique()), 'genes')

10164 images
2106 genes


Fit the model.

In [5]:
if TRAIN:
    # Only this cell needs these modules, so they are imported locally.
    import shutil
    from pathlib import Path

    # Fresh model configured from the constants in the configuration cell.
    model = ContrastiveEmbedding(
        embedding_dim=EMBEDDING_DIM,
        patch_size=PATCH_SIZE,
        encoder_type=ENCODER,
        temperature=TEMPERATURE,
        learning_rate=LEARNING_RATE,
    )

    # `indicator` restricts the data to the curated training images and
    # `grouper` returns each item's gene.
    # NOTE(review): presumably the grouping defines the contrastive pairs --
    # confirm in src.datamodule.
    dm = ContrastiveDataModule(
        DATA_DIR,
        image_ext='png',
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        batch_size=BATCH_SIZE,
        indicator=lambda item: item['Image'] in train_images,
        grouper=lambda item: item['Gene'],
        cache_dir=CACHE_DIR,
        num_workers=NUM_WORKERS,
        random_state=0
    )

    # !rm -rf ./lightning_logs
    start_tensorboard(login_node='login-2')

    # min_epochs == max_epochs pins the run to NUM_EPOCHS epochs.
    trainer = Trainer(
        gpus=2,
        precision=16,
        strategy='dp',
        log_every_n_steps=5,
        min_epochs=NUM_EPOCHS,
        max_epochs=NUM_EPOCHS
    )

    trainer.fit(model, dm)

    # Export the most recently written checkpoint to CHECKPOINT.
    # The directory is taken from the trainer's checkpoint callback rather than
    # from a hand-built `./lightning_logs/default/{version}/...` path: `version`
    # was never defined in this notebook, so the old shell lookup matched
    # nothing and the cell failed after training had already finished.
    checkpoint_callback = trainer.checkpoint_callback
    if checkpoint_callback is None or not checkpoint_callback.dirpath:
        raise RuntimeError('Trainer has no checkpoint directory; nothing to export.')

    last_ckpt = max(
        Path(checkpoint_callback.dirpath).glob('*.ckpt'),
        key=lambda path: path.stat().st_mtime,
        default=None,
    )
    if last_ckpt is None:
        raise FileNotFoundError(
            f'No *.ckpt files found in {checkpoint_callback.dirpath}'
        )
    shutil.copyfile(last_ckpt, CHECKPOINT)

Embed the entire dataset of images.

In [None]:
if EVAL:
    # Restore the trained model from CHECKPOINT, passing the same
    # hyperparameters that the training cell uses.
    model = ContrastiveEmbedding.load_from_checkpoint(
        CHECKPOINT,
        embedding_dim=EMBEDDING_DIM,
        patch_size=PATCH_SIZE,
        encoder_type=ENCODER,
        temperature=TEMPERATURE,
        learning_rate=LEARNING_RATE
    )

    # Unlike the training cell, the indicator accepts every item, so all
    # images are embedded rather than only the curated training subset.
    dm = ContrastiveDataModule(
        DATA_DIR,
        image_ext='png',
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        batch_size=BATCH_SIZE,
        indicator=lambda item: True,
        grouper=lambda item: item['Gene'],
        cache_dir=CACHE_DIR,
        num_workers=NUM_WORKERS,
        random_state=0
    )
    dm.setup()

    # Inference mode on the GPU.
    model.eval()
    model = model.cuda()

    embeddings = []
    images = []
    # Gradients are never needed here, so the no-grad context is entered once
    # for the whole loop instead of once per image.
    with torch.no_grad():
        for item in tqdm(dm.test_dataset, position=0):
            # unsqueeze(0) adds a leading dimension so one image forms a batch.
            z = model(item['image'].to(model.device).unsqueeze(0))
            images.append(item['name'])
            embeddings.append(z.detach().cpu().numpy())
    embeddings = np.concatenate(embeddings, axis=0)

    # dump to file: the embedding array, plus the image names in the same order
    np.save(f'./data/{TISSUE}_embeddings.npy', embeddings)
    with open(f'./data/{TISSUE}_embeddings.txt','w') as f:
        f.write('\n'.join(images))

    # A bare expression inside an `if` block is not displayed by Jupyter, so
    # report the count explicitly.
    print(len(embeddings), 'embeddings')

 60%|█████▉    | 36703/61539 [13:14<08:32, 48.50it/s]