### Reproduce kaggle notebook

#### Assumptions: 
1. We have a trained feature extractor saved in `/scr/mdoron/Dino4Cells/models/data-fixed_size_masked/fixed_unmasked_checkpoint0100.pth`
2. The feature extractor was trained on fixed, unmasked cells who were segmented with Cellpose
3. The feature extractor was trained without self-normalization

In [22]:
# Imports
import os
from pathlib import Path

### Training classifier head

##### Slow training with dynamic feature extraction:

In [4]:
config_path = 'configs/config_train_classifier_head_single_cells.yaml'
feature_extractor_path = '/scr/mdoron/Dino4Cells/models/data-fixed_size_masked/fixed_unmasked_checkpoint0100.pth'

os.system(f'''python run_end_to_end.py --config {config_path} --gpus 0,1,2,3 
          --feature_extractor_state_dict {feature_extractor_path}
          --use_pretrained_features False
          ''')

assert(Path(f"{config['classification']['output_dir']/config['classification']['output_prefix']}/classifier_final.pth").exists())

##### Faster training with state features:

First, extracting the features, saving them to `configs['embedding']['output_path']`

In [None]:
# Extracting the features, saving them to `configs['embedding']['output_path']`
os.system(f'CUDA_VISIBLE_DEVICES=0,1,2,3 python run_get_features.py --config {config_path} --gpus 0,1,2,3')
train_path = '/scr/mdoron/Dino4Cells/models/without-norm-fixed-unmasked-mom-0.996-epoch-100/train_features_averaged.pth'
valid_path = '/scr/mdoron/Dino4Cells/models/without-norm-fixed-unmasked-mom-0.996-epoch-100/valid_features_averaged.pth'

# Average-pooling the features, saving them in `train_path` and `valid_path`
os.system(f'python prepare_averaged_features.py --config {config_path} --train_path {train_path} --valid_path {valid_path}')

assert(Path(train_path).exists())
assert(Path(valid_path).exists())

Then, train a classifier head with static features

In [None]:
# training classifier head
os.system(f'''python run_end_to_end.py --config {config_path} --gpus 0,1,2,3 
          --use_pretrained_features True
          --train_path 
          --valid_path 
          ''')

assert(Path(f"{config['classification']['output_dir']/config['classification']['output_prefix']}/classifier_final.pth").exists())

Both dynamic and static options will save the classifier head state dict to `config['classification']['output_dir']['output_prefix']/classifier_final.pth`

In [24]:
assert(Path(f"{config['classification']['output_dir']}/{config['classification']['output_prefix']}/classifier_final.pth").exists())

AssertionError: 

### Extracting kaggle test features

### Producing kaggle submission file

In [25]:
os.system('python run_end_to_end.py --config {config_path} --gpus 0,1,2,3')

'/scr/mdoron/Dino4Cells/results//single_cell_fixed_masked_simpleClf/classifier_final.pth'

In [13]:
from pprint import pprint
pprint(config['classification'])

{'averaged_features': '/scr/mdoron/Dino4Cells/supervised_vit_512/features.pth',
 'balance': True,
 'batch_size_per_gpu': 512,
 'cell_type_classifier': '/scr/mdoron/Dino4Cells/supervised_vit_512/cell_type_classifier.pth',
 'classifier_state_dict': '/scr/mdoron/Dino4Cells/results//single_cell_fixed_masked/SimpleClf_e1_lr1e-3_div25_wd1e-4_p0.5_BCELoss_v1_checkpoint_final.pth',
 'classifier_type': 'simple_clf',
 'competition_type': 'single_cells',
 'dropout': 0.5,
 'early_stopping': -1,
 'epochs': 10,
 'feature_extractor_state_dict': None,
 'infer_test_data': False,
 'loss': 'BCEWithLogitsLoss',
 'lr': 0.0001,
 'min_lr': '1e-6',
 'n_layers': 2,
 'n_units': 1024,
 'num_classes': 19,
 'num_workers': 14,
 'optimizer': 'AdamW',
 'output_dir': '/scr/mdoron/Dino4Cells/results/',
 'output_prefix': 'single_cell_fixed_masked_simpleClf',
 'overwrite': True,
 'prepare_kaggle_submission': True,
 'protein_classifier': '/scr/mdoron/Dino4Cells/supervised_vit_512/protein_classifier.pth',
 'protein_task': 

In [30]:
from skimage import io
io.imread('/scr/mdoron/Dino4Cells/data/single_cell_test_data_varied_unmasked/0040581b-f1f2-4fbe-b043-b6bfea5404bb_1.png')

array([[[ 15,   0,   0,   0],
        [ 13,   0,   0,   0],
        [ 11,   2,   0,   0],
        ...,
        [  0,   0,   0,   0],
        [  0,   0,   0,   0],
        [  0,   0,   0,   0]],

       [[ 12,   9,   0,   1],
        [ 14,   1,   0,   4],
        [ 13,   6,   0,   0],
        ...,
        [  0,   0,   0,   0],
        [  0,   0,   0,   0],
        [  0,   0,   0,   0]],

       [[ 14,   3,   0,   1],
        [ 14,   8,   0,   0],
        [ 12,   5,   0,   0],
        ...,
        [  0,   0,   0,   0],
        [  0,   0,   0,   0],
        [  0,   0,   0,   0]],

       ...,

       [[  0,   0,   0,   0],
        [  0,   0,   0,   0],
        [  0,   0,   0,   0],
        ...,
        [ 32,  25,   0,  56],
        [ 24,  14,   0,  32],
        [ 19,   4,   0,  19]],

       [[  0,   0,   0,   0],
        [  0,   0,   0,   0],
        [  0,   0,   0,   0],
        ...,
        [ 29,   6,   0, 102],
        [ 24,   6,   0,  80],
        [ 23,  35,   0,  59]],

       [[  0