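# 2. Fit SimCLR

Train (or load from a checkpoint) a SimCLR-style contrastive encoder on the filtered kidney HPA images, then export one embedding per image for the downstream classification task.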

In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
# Seed every RNG the notebook may draw from (Python `random`, NumPy, PyTorch) so
# a fresh-kernel "Restart & Run All" reproduces the same random draws. Seeding
# torch alone left NumPy and `random` unseeded.
import random

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Reload imported modules before each cell executes, so edits to the local
# `src` package (imported in later cells) take effect without a kernel restart.
%load_ext autoreload 
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df = pd.read_csv('./data/hpa_v21_kidney.csv', index_col=0)

# Some images are associated with multiple genes (i.e. nonspecific antibody); we remove these.
# `Index.duplicated(keep=False)` flags *every* row whose image ID occurs more than
# once, without relying on label-aligned assignment over a duplicated index.
df['duplicated'] = df.index.duplicated(keep=False)

# Only include high-quality images: high/medium staining with "Enhanced" reliability.
train_df = (
    df.query('Staining in ["high", "medium"]')
      .query('Reliability == "Enhanced"')
      .query('~duplicated')
)

# With every duplicated image removed, each remaining image ID must be unique.
assert train_df.index.is_unique, 'duplicate image IDs survived filtering'

train_images = set(train_df.index)

print(len(train_images), 'images')
print(len(train_df['Gene'].unique()), 'genes')

10164 images
2106 genes


In [3]:
import os

from src.datamodule import ContrastiveDataModule

IMAGE_SIZE = 512
PATCH_SIZE = 256
BATCH_SIZE = 150
NUM_WORKERS = 16
GROUP_BY = 'Gene'

# Python does not expand shell variables inside string literals, so the bare
# '$TMPDIR/images' would be passed through verbatim. Expand it explicitly; if
# TMPDIR is unset, expandvars leaves the string unchanged (same as before).
CACHE_DIR = os.path.expandvars('$TMPDIR/images')

# `indicator` keeps only items whose 'Image' is in the filtered `train_images`;
# `grouper` returns each item's GROUP_BY field (presumably used to group
# positives for the contrastive loss -- confirm in src.datamodule).
dm = ContrastiveDataModule(
    './data/images',
    image_ext='png',
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    batch_size=BATCH_SIZE,
    indicator=lambda item: item['Image'] in train_images,
    grouper=lambda item: item[GROUP_BY],
    cache_dir=CACHE_DIR,
    num_workers=NUM_WORKERS,
    random_state=0
)

In [33]:
from src.model import ContrastiveEmbedding
from pytorch_lightning import Trainer
from src.tensorboard import start_tensorboard

TRAIN = False
NUM_EPOCHS = 10
EMBEDDING_DIM = 128
TEMPERATURE = 1.0
LEARNING_RATE = 5e-4
ENCODER = 'densenet121'
CHECKPOINT_PATH = './data/kidney.ckpt'

# Single source of truth for the model hyperparameters, shared by the training
# and checkpoint-loading branches so the two cannot silently drift apart.
MODEL_HPARAMS = dict(
    embedding_dim=EMBEDDING_DIM,
    patch_size=PATCH_SIZE,
    encoder_type=ENCODER,
    temperature=TEMPERATURE,
    learning_rate=LEARNING_RATE,
)

if TRAIN:
    model = ContrastiveEmbedding(**MODEL_HPARAMS)

    # To start from clean logs, delete ./lightning_logs before training.
    start_tensorboard(login_node='login-2')

    # NOTE(review): `gpus=` and strategy='dp' depend on the installed
    # pytorch_lightning version -- confirm they are supported there.
    trainer = Trainer(
        gpus=2,
        precision=16,
        strategy='dp',
        log_every_n_steps=5,
        min_epochs=NUM_EPOCHS,
        max_epochs=NUM_EPOCHS
    )

    trainer.fit(model, dm)

else:
    # Load the pretrained checkpoint. To use the newest local training run
    # instead, point this at the most recent
    # ./lightning_logs/*/checkpoints/*.ckpt file.
    last_ckpt = CHECKPOINT_PATH
    model = ContrastiveEmbedding.load_from_checkpoint(last_ckpt, **MODEL_HPARAMS)
    # Only this branch calls setup() explicitly; the training branch presumably
    # relies on trainer.fit to set up `dm` -- confirm before using test loaders.
    dm.setup()

In [47]:
# Evaluate each embedding on the whole image, for the downstream classification task.

# Fall back to CPU when no GPU is available instead of failing on `.cuda()`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
model = model.to(device)

embeddings = []
images = []
# Gradients are never needed here: disable autograd once for the whole loop
# (this also makes an explicit `.detach()` unnecessary).
with torch.no_grad():
    for batch in tqdm(dm.test_dataloader(), position=0):
        z = model(batch['image'].to(model.device))
        images.append(batch['Image'])
        embeddings.append(z.cpu().numpy())
embeddings = np.concatenate(embeddings, axis=0)
images = np.concatenate(images, axis=0)

# Row i of embeddings.npy must correspond to line i of embeddings.txt.
assert len(embeddings) == len(images), (len(embeddings), len(images))

np.save('./data/embeddings.npy', embeddings)
with open('./data/embeddings.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(images))