# 2. Fit SimCLR

In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
# Seed PyTorch's global random number generator so repeated runs start from the same state.
torch.manual_seed(0)

# Load IPython's autoreload extension; mode 2 reloads all modules before each
# cell is executed, so edits to project modules are picked up without restarting the kernel.
%load_ext autoreload 
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.model import ContrastiveEmbedding
from src.datamodule import ContrastiveDataModule
from src.util import start_tensorboard
from pytorch_lightning import Trainer

In [29]:
# --- Run switches ----------------------------------------------------------
TRAIN = True   # gates the model-fitting cell
EVAL = True    # gates the embedding cell

# --- Dataset identity and file locations -----------------------------------
TISSUE = 'kidney'
VERSION = 'final'

DATA_DIR = f'./data/images/{TISSUE}'
# NOTE(review): '$TMPDIR' is kept as literal text here; presumably whoever
# consumes cache_dir expands it -- confirm.
CACHE_DIR = f'$TMPDIR/images/{TISSUE}'
CHECKPOINT = f'./data/weights/{TISSUE}_{VERSION}.ckpt'

# --- Image geometry and data loading ---------------------------------------
IMAGE_SIZE = 512
PATCH_SIZE = 256
BATCH_SIZE = 150
NUM_WORKERS = 16

# --- Model and optimisation ------------------------------------------------
ENCODER = 'densenet121'
EMBEDDING_DIM = 128
TEMPERATURE = 1.0
LEARNING_RATE = 5e-4
NUM_EPOCHS = 1000
NUM_GPUS = 1 # torch.cuda.device_count()

# --- Contrastive sampling options ------------------------------------------
POS_REJECTION = False      # handed to the model as positive_masking
NEG_REJECTION = 'Patient'  # handed to the model as negative_masking
# NOTE(review): GENE_GROUPING is not referenced by the cells visible here
# (the training cell hard-codes grouper='Gene') -- confirm whether it is still needed.
GENE_GROUPING = True

Prepare the subset of training images for which correct antibody binding is plausible.

In [30]:
# Load the per-image metadata table; the first CSV column becomes the index.
df = pd.read_csv(f'./data/hpa_v21_{TISSUE}.csv', index_col=0)

# Some images are associated with multiple genes (i.e. nonspecific antibody); we remove these.
# `keep=False` flags every row whose index label occurs more than once, which
# avoids building and re-aligning a Series on a non-unique index.
df['duplicated'] = df.index.duplicated(keep=False)

# Only train on high-quality images: high/medium staining, "Enhanced"
# reliability, and an index label that is not shared with another row.
train_df = (
    df.query('Staining in ["high", "medium"]')
      .query('Reliability=="Enhanced"')
      .query('~duplicated')
)

# Index labels of the rows that survived the filters; used later to select training images.
train_images = set(train_df.index)

print(len(train_images), 'images')
print(len(train_df['Gene'].unique()), 'genes')

10164 images
2106 genes


Fit the model.

In [5]:
if TRAIN:
    # Train only on the vetted images and group samples by gene.
    # NOTE(review): the exact meaning of `indicator` / `grouper` is defined in
    # src.datamodule, which is not visible here -- confirm there.
    indicator = ('Image', train_images)
    grouper = 'Gene'

    # Data-loading options that must be identical for the model and the data
    # module. The model currently expects them as well (kludge), so they are
    # defined once to keep the two call sites from drifting apart.
    data_kwargs = dict(
        image_ext='png',
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        batch_size=BATCH_SIZE,
        indicator=indicator,
        grouper=grouper,
        cache_dir=CACHE_DIR,
        num_workers=NUM_WORKERS,
        random_state=0,
    )

    model = ContrastiveEmbedding(
        embedding_dim=EMBEDDING_DIM,
        encoder_type=ENCODER,
        temperature=TEMPERATURE,
        learning_rate=LEARNING_RATE,
        positive_masking=POS_REJECTION,
        negative_masking=NEG_REJECTION,
        # kludge: the model also receives the data-loading configuration
        image_dir=DATA_DIR,
        **data_kwargs,
    )

    dm = ContrastiveDataModule(DATA_DIR, **data_kwargs)

    start_tensorboard(login_node='login-2')

    # min_epochs == max_epochs: always train for exactly NUM_EPOCHS epochs.
    trainer = Trainer(
        gpus=NUM_GPUS,
        precision=16,
#         strategy='ddp_spawn',
        log_every_n_steps=5,
        min_epochs=NUM_EPOCHS,
        max_epochs=NUM_EPOCHS,
    )

    trainer.fit(model, dm)

    # Export the weights of the run that just finished. Asking the trainer to
    # write the checkpoint avoids guessing the file by modification time under
    # ./lightning_logs, which could pick up another run's checkpoint or fail
    # when the glob matches nothing.
    trainer.save_checkpoint(CHECKPOINT)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-590cd023-ff31-c1a2-8267-e0afba042725]
Set SLURM handle signals.

  | Name            | Type       | Params
-----------------------------------------------
0 | encoder         | Sequential | 7.1 M 
1 | projection_head | Sequential | 33.0 K
2 | _transform      | Sequential | 0     
-----------------------------------------------
7.1 M     Trainable params
0         Non-trainable params
7.1 M     Total params
14.236    Total estimated model params size (MB)


Epoch 99: 100%|██████████| 8/8 [00:18<00:00,  2.32s/it, loss=0.00303, v_num=1.45e+7]


Embed the entire dataset of images.

In [27]:
if EVAL:
    # Accept every image and treat each image as its own group.
    # NOTE(review): the exact meaning of `indicator` / `grouper` is defined in
    # src.datamodule, which is not visible here -- confirm there.
    indicator = lambda _: True
    grouper = 'Image'

    # Restore the trained model from the exported checkpoint.
    model = ContrastiveEmbedding.load_from_checkpoint(
        CHECKPOINT,
        embedding_dim=EMBEDDING_DIM,
        patch_size=PATCH_SIZE,
        encoder_type=ENCODER,
        temperature=TEMPERATURE,
        learning_rate=LEARNING_RATE,
        positive_masking=POS_REJECTION,
        negative_masking=NEG_REJECTION,
    )

    dm = ContrastiveDataModule(
        DATA_DIR,
        image_ext='png',
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        batch_size=BATCH_SIZE,
        indicator=indicator,
        grouper=grouper,
        cache_dir=CACHE_DIR,
        num_workers=NUM_WORKERS,
        random_state=0
    )
    dm.setup()

    # Inference mode on the GPU when one is present, otherwise on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    model = model.to(device)

    images = []
    per_image_embeddings = []
    # One no_grad context for the whole pass: no autograd graph is built, so
    # the outputs can be converted to NumPy without an explicit detach().
    with torch.no_grad():
        for item in tqdm(dm.test_dataset, position=0):
            # Each dataset item is a single image; add a leading batch axis for the model.
            z = model(item['image'].to(device).unsqueeze(0))
            images.append(item['name'])
            per_image_embeddings.append(z.cpu().numpy())

    # One row per image, labelled with the image name, written as CSV.
    embeddings = pd.DataFrame(np.concatenate(per_image_embeddings, axis=0), index=images)
    embeddings.to_csv(f'./data/{TISSUE}_{VERSION}_embeddings.csv',sep=',')

    # A bare expression inside an `if` block is never displayed by the
    # notebook, so report the number of embedded images explicitly.
    print(len(embeddings), 'embeddings')

100%|██████████| 45740/45740 [15:21<00:00, 49.64it/s] 
