# 2. Fit SimCLR

In [None]:
# pandas loads the image metadata table; torch is imported here to seed its RNG.
import pandas as pd
import torch
# Seed torch's global random number generator so repeated runs start from the same state.
torch.manual_seed(0)

# IPython autoreload, mode 2: re-import all modules before each cell executes,
# so edits to the project's own source files are picked up without a kernel restart.
%load_ext autoreload 
%autoreload 2

In [None]:
# Load the metadata table; the first CSV column becomes the index, which is
# used below as the image identifier.
df = pd.read_csv('./data/hpa_v21_kidney.csv', index_col=0)

# Some images are associated with multiple genes (e.g. a nonspecific antibody);
# we remove these. `keep=False` flags *every* row whose index label occurs more
# than once, not just the second and later occurrences.
df['duplicated'] = df.index.duplicated(keep=False)

# Only include high-quality images: high or medium staining, "Enhanced"
# reliability, and an image identifier that appears exactly once in the table.
train_df = (
    df.query('(Staining=="high")|(Staining=="medium")')
      .query('Reliability=="Enhanced"')
      .query('~duplicated')
)

# Set of image identifiers used for O(1) membership tests when selecting
# training items.
train_images = set(train_df.index)

print(len(train_images), 'training images')
print('covering', len(train_df['Gene'].unique()), 'genes')

In [None]:
import os

from src.datamodule import ContrastiveDataModule

# Data module feeding the contrastive model. Only images selected into
# `train_images` are used, and items are grouped by their 'Gene' field.
dm = ContrastiveDataModule(
    './data/images',
    image_ext='png8',
    image_size=512,
    patch_size=256,
    # NOTE(review): the original comment here read "150 !!!" -- presumably the
    # batch size intended for the full-scale run; confirm before training for real.
    batch_size=16,
    indicator=lambda item: item['Image'] in train_images,
    grouper=lambda item: item['Gene'],
    # Python does not expand shell variables inside string literals, so expand
    # $TMPDIR explicitly instead of relying on the data module to do it.
    # If TMPDIR is unset, expandvars leaves the string unchanged (the previous
    # behaviour).
    cache_dir=os.path.expandvars('$TMPDIR'),
    num_workers=4,
    random_state=0
)

In [None]:
from src.model import ContrastiveEmbedding

# Hyperparameters for the contrastive embedding model, collected in one place
# so they are easy to inspect and tweak.
model_hparams = dict(
    embedding_dim=128,
    encoder_type='densenet121',
    temperature=1.0,
    learning_rate=5e-4,
)

model = ContrastiveEmbedding(**model_hparams)

In [None]:
from src.tensorboard import start_tensorboard

# Cluster login node handed to the project's TensorBoard helper.
# NOTE(review): presumably used to reach TensorBoard from outside the compute
# node -- confirm against src/tensorboard.py.
tensorboard_login_node = 'login-2'

start_tensorboard(login_node=tensorboard_login_node)

In [None]:
from pytorch_lightning import Trainer

num_epochs = 10

!rm -rf ./lightning_logs
trainer = Trainer(
    gpus=2,
    precision=16,
    strategy='dp',
    log_every_n_steps=5,
    min_epochs=num_epochs,
    max_epochs=num_epochs
)

trainer.fit(model, dm)