## Codebook Design via Variational Reinforcement Learning

In [None]:
import os
import tarfile
from ruamel_yaml import YAML
from ruamel_yaml.comments import CommentedMap
import logging
import pprint
import time
import math

import matplotlib.pylab as plt
import numpy as np
import torch
from boltons.cacheutils import cachedproperty, cachedmethod
from typing import Dict, Any, Callable, Optional, Dict, List, Set, Tuple, Union

from mighty_codes import consts
from mighty_codes import metric_utils
from mighty_codes import experiments

from mighty_codes.torch_utils import \
    to_np, \
    to_torch, \
    to_one_hot_encoded

from mighty_codes.experiments import \
    ChannelModelSpecification

from mighty_codes.rl.mrf_one_body import \
    OneBodyMRFPotential, \
    NeuralPIOneBodyMRFPotential

from mighty_codes.rl.mrf_two_body import \
    TwoBodyMRFPotential, \
    NeuralGeneralPITwoBodyMRFPotential, \
    TwoBodyPotentialPrefactor, \
    BiLinearTwoBodyPotentialPrefactor, \
    BiLinearTwoBodyPotentialPrefactorSimple, \
    NeuralTwoBodyPotentialPrefactor, \
    ExpBiLinearPITwoBodyMRFPotential

from mighty_codes.rl.mrf_code_gen import \
    MRFCodebookGenerator, \
    code_space_filter_func_rel_symbol_weight, \
    code_space_filter_func_abs_hamming_weight


# shared YAML handler with explicit indentation settings for mappings and sequences
yaml = YAML()
yaml.indent(mapping=2, sequence=4, offset=2)

# every progress message in this notebook is emitted through `log_info`
# (bound to plain `print` here)
log_info = print

In [None]:
# Devices used for codebook generation (`device_gen`) and for channel-model
# evaluation (`device_eval`). Both target the first CUDA device when one is
# available and otherwise fall back to the CPU, so the notebook still runs
# (slowly) on a machine without a GPU instead of failing at the first `.to()`.
device_gen = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_eval = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

# cuDNN is enabled; its benchmark mode is switched off
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False

## Setup codebook generator

In [None]:
# Single configuration dictionary read by the cells below (experiment
# generation, codebook sampling, decoding, metrics and logging).
# The trailing `# <value>` notes on some entries appear to record alternative
# settings -- TODO confirm.
params = {
    
    'assert_shapes': True,
    'experiment_prefix': 'rl_bac',
    
    # ranges handed to `experiments.generate_experiment_spec` in the training loop
    'min_rel_symbol_weight_s': [0.125, 0.125],
    'max_rel_symbol_weight_s': [0.875, 0.875],
    'min_n_types': 50,  # 16
    'max_n_types': 300,  # 256
    'min_code_length': 16,  # 6
    'max_code_length': 16,  # 16
    'min_source_nonuniformity': 10.,  # 1.
    'max_source_nonuniformity': 1000.,
    
    # training-loop controls: experiments accumulated per optimizer step,
    # codebooks sampled per experiment, and the REINFORCE baseline settings
    'experiments_per_iter': 1,
    'batch_size': 10,  # 4
    'n_iters': 1_000_000,
    'action_policy': 'sampled',
    'action_policy_epsilon': 0.5,
    'use_running_average_baseline': False,
    'baseline_ma_beta': 0.9,
    
    # BAC channel
    'channel_model': 'channel_bac_merfish',
    'decoder_type': 'posterior_sampled',
    'decoder_kwargs': {
        'split_size': 1},  # 1

    # Gaussian channel
    # 'channel_model': 'channel_gaussian_merfish',
    # 'decoder_type': 'posterior_sampled',
    # 'decoder_kwargs': {
    #     'n_samples_per_type': 1_000,
    #     'max_n_samples_per_type_per_sampling_round': 1_000},

    # NOTE(review): 'optim_type' / 'optim_kwargs' are only referenced by the
    # commented-out optimizer construction; the active optimizer cell
    # hard-codes Adam and its learning rates.
    'optim_type': 'adam',
    'optim_kwargs': {
        'lr': 1e-2,
        'betas': (0.5, 0.9)},

    # 'optim_type': 'rmsprop',
    # 'optim_kwargs': {'lr': 1e-3, 'alpha': 0.99},

    # 'optim_type': 'adam',
    # 'optim_kwargs': {'lr': 1e-3, 'betas': (0.5, 0.99)},

    # 'optim_type': 'rmsprop',
    # 'optim_kwargs': {'lr': 1e-4, 'alpha': 0.99},
    
    # forwarded to `metric_utils` to turn decoder output into a per-codebook optimality
    'metrics_dict_type': 'basic',
    'optimality_type': 'fdr',
    'metrics_kwargs': {},

    # two-body interaction settings forwarded to `codebook_generator.forward`
    'disable_two_body': True,  # False <======== NOTE
    'enable_hardcore_two_body_potential': True,
    'two_body_max_n_interactions': 64,
    'two_body_dropout_policy': 'random',

    # proposal settings forwarded to `codebook_generator.forward`
    'top_proposals_per_column': 512,
    'random_proposals_per_column': 512,
    'top_k_noise_std': 0.5,
    
    # batch statistic used as baseline: 'mean' or 'median'
    'baseline_reduction_type': 'mean',
    
    # console logging layout; logged optimality values are multiplied by
    # 'log_optimality_scale'
    'log_frequency': 1,
    'log_sig_digits': 4,
    'log_column_width_large': 30,
    'log_column_width_small': 16,
    'log_optimality_scale': 100.,
}

In [None]:
# gather every ChannelModelSpecification instance found among the module-level
# attributes of `experiments`, keyed by attribute name
all_channel_model_specs_dict: Dict[str, ChannelModelSpecification] = {
    attr_name: attr_value
    for attr_name, attr_value in vars(experiments).items()
    if isinstance(attr_value, ChannelModelSpecification)}

# pick the configured channel and place its model on the evaluation device
# with the working dtype
channel_spec = all_channel_model_specs_dict[params['channel_model']]
channel_model = channel_spec.channel_model.to(device_eval).type(dtype)

In [None]:
# one-body potential
# Network layout handed to NeuralPIOneBodyMRFPotential: tuples interleaved with
# activation names (presumably (in_features, out_features) per layer -- TODO
# confirm against the class). The last layer has `n_symbols + 1` outputs.
one_body_nn_specs = [
    (6, 5),
    'elu',
    (5, 5),
    'elu', 
    (5, channel_model.n_symbols + 1)]

one_body_lp_norm = 2

one_body_potential = NeuralPIOneBodyMRFPotential(
    n_symbols=channel_model.n_symbols,
    nn_specs=one_body_nn_specs,
    lp_norm=one_body_lp_norm,
    assert_shapes=params['assert_shapes']).to(device_gen)

# two-body potential
# NOTE: the two-body potential is always constructed and handed to the
# generator, but the training loop passes params['disable_two_body'] to
# `forward`, which is True in the configuration above.
n_components = 1

two_body_potential_prefactor_provider = BiLinearTwoBodyPotentialPrefactor(
    n_components=n_components,
    init_bilinear_scale=1.0,
    init_linear_scale=0.01,
    init_constant_scale=0.01).to(device_gen)

two_body_potential = ExpBiLinearPITwoBodyMRFPotential(
    n_symbols=channel_model.n_symbols,
    n_components=n_components,
    code_length_reduction_type='mean',
    prefactor_provider=two_body_potential_prefactor_provider,
    init_bilinear_diag_scale=0.05,
    init_bilinear_rand_scale=0.0,
    init_linear_scale=0.0,
    init_constant_scale=-5.0).to(device_gen)

# Alternative (inactive) two-body models kept for reference.
# NOTE(review): these snippets use a bare `n_symbols`, which is not defined in
# this notebook; `channel_model.n_symbols` would be needed to re-enable them.

# eta_nnet_specs = [(n_symbols + 1, 5), 'elu']
# xi_nnet_specs = [(5, 10), 'elu']
# psi_nnet_specs = [(10, 1)]

# two_body_potential = NeuralGeneralPITwoBodyMRFPotential(
#     n_symbols=n_symbols,
#     n_meta=0,
#     eta_nnet_specs=eta_nnet_specs,
#     xi_nnet_specs=xi_nnet_specs,
#     psi_nnet_specs=psi_nnet_specs).to(device_gen)


# two_body_potential_prefactor_provider = NeuralTwoBodyPotentialPrefactor(
#     n_components=n_components,
#     eta_nnet_specs=[(1, 5), 'elu'],
#     psi_nnet_specs=[(5, n_components)]).to(device_gen)


# two_body_potential_prefactor_provider = BiLinearTwoBodyPotentialPrefactorSimple(
#     n_components=n_components).to(device_gen)

In [None]:
# code space filter
def code_space_filter_func(c_als, code_length):
    """Apply the relative-symbol-weight code-space filter with bounds from ``params``.

    Thin wrapper around ``code_space_filter_func_rel_symbol_weight``. The
    bounds are read from ``params`` on every call, so changing
    ``params['min_rel_symbol_weight_s']`` / ``params['max_rel_symbol_weight_s']``
    takes effect without re-creating the generator.

    Args:
        c_als: candidate codes, forwarded unchanged to the wrapped filter.
        code_length: code length, forwarded unchanged to the wrapped filter.

    Returns:
        Whatever ``code_space_filter_func_rel_symbol_weight`` returns.
    """
    return code_space_filter_func_rel_symbol_weight(
        c_als=c_als,
        code_length=code_length,
        min_rel_symbol_weight_s=np.asarray(params['min_rel_symbol_weight_s']),
        max_rel_symbol_weight_s=np.asarray(params['max_rel_symbol_weight_s']))


# generator combining the one-body and two-body potentials; it lives on
# `device_gen` and restricts the code space through the filter above
codebook_generator = MRFCodebookGenerator(
    one_body_potential=one_body_potential,
    two_body_potential=two_body_potential,
    n_symbols=channel_model.n_symbols,
    code_space_filter_func=code_space_filter_func,
    device=device_gen,
    dtype=dtype,
    assert_shapes=params['assert_shapes'])

In [None]:
# Generic optimizer construction driven by params['optim_type'] and
# params['optim_kwargs'] (currently inactive):
#
# # optimizer
# optim_str_to_obj_map = {
#     'adam': torch.optim.Adam,
#     'rmsprop': torch.optim.RMSprop
# }

# optim = optim_str_to_obj_map[params['optim_type']](
#     params=list(codebook_generator.parameters()),
#     **params['optim_kwargs'])

# Active optimizer: Adam with one parameter group per potential, so that the
# two-body potential trains with a 10x smaller learning rate than the one-body
# potential. The trailing keyword arguments are the optimizer-wide defaults.
optim = torch.optim.Adam(
    params=[
        dict(
            params=codebook_generator.one_body_potential.parameters(),
            lr=1e-2,
            betas=(0.5, 0.9)),
        dict(
            params=codebook_generator.two_body_potential.parameters(),
            lr=1e-3,
            betas=(0.5, 0.9)),
    ],
    betas=(0.5, 0.9),
    lr=1e-2)

## Optimization Loop

In [None]:
# codebook_generator.load_state_dict(torch.load('./codebook_generator_state__multi__4.pt'))
# optim.load_state_dict(torch.load('./optim__multi__4.pt'))

# # override two-body parameters
# two_body_potential.gamma_m.data = torch.tensor([-1.]).to(device_gen)
# two_body_potential.gamma_ms.data = torch.tensor([[0., 0.]]).to(device_gen)
# two_body_potential.gamma_mss_unconstrained.data = torch.tensor([[[0.2, 0.], [0., 0.2]]]).to(device_gen)

# codebook_generator.load_state_dict(
#     torch.load('./codebook_generator_state__greedy__140__MERFISH__full_interaction__test.pt'),
#     strict=False)

# codebook_generator.load_state_dict(torch.load('./codebook_generator_state__1.pt'))
# # optim.load_state_dict(torch.load('./optim__1.pt'))

In [None]:
# Training state. It is kept in its own cell so the optimization loop below can
# be interrupted and resumed without losing progress; re-running this cell
# restarts the iteration counter, the histories and the running baseline.
i_iter = 0
optimality_hist = []
baseline_hist = []
baseline_running_average = 0.

In [None]:
# for g in optim.param_groups:
#     g['lr'] = 1e-3
#     g['betas'] = (0.5, 0.99)

In [None]:
# params['enable_hardcore_two_body_potential'] = False
# params['disable_two_body'] = False
# params['baseline_reduction_type'] = 'mean'

In [None]:
# logging parameters
log_frequency = params['log_frequency']
small_col = params['log_column_width_small']
large_col = params['log_column_width_large']
sig_digits = params['log_sig_digits']
log_optimality_scale = params['log_optimality_scale']

# batch statistic used as the REINFORCE baseline; looked up once because it is
# invariant for the duration of the loop
baseline_reduction = {
    'mean': torch.mean,
    'median': torch.median}[params['baseline_reduction_type']]


def _format_mean_range(values_b, scale, n_digits):
    """Return ``'mean (min, max)'`` of the tensor `values_b`, each scaled by `scale`."""
    scaled_mean = scale * values_b.mean().item()
    scaled_lo = scale * values_b.min().item()
    scaled_hi = scale * values_b.max().item()
    return f"{scaled_mean:.{n_digits}f} ({scaled_lo:.{n_digits}f}, {scaled_hi:.{n_digits}f})"


header_string = \
    f"{'i_iter'.ljust(small_col)}" \
    f"{'code_length'.ljust(small_col)}" \
    f"{'n_types'.ljust(small_col)}" \
    f"{'nonuniformity'.ljust(small_col)}" \
    f"{'optimality'.ljust(large_col)}" \
    f"{'optimality_compl'.ljust(large_col)}" \
    f"{'baseline'.ljust(small_col)}"

log_info(f"Logging optimality scale factor: {log_optimality_scale:.3f}")
log_info("")
log_info(header_string)
log_info('=' * len(header_string))

# Each iteration accumulates REINFORCE gradients over
# params['experiments_per_iter'] randomly generated experiments and then takes
# a single optimizer step. `i_iter` lives in the notebook namespace, so the
# loop resumes where it stopped when the cell is re-run.
while i_iter < params['n_iters']:

    # zero grad
    optim.zero_grad()

    # accumulate grad
    for i_expmt in range(params['experiments_per_iter']):

        # generate experiment
        experiment_spec = experiments.generate_experiment_spec(
            name_prefix=params['experiment_prefix'],
            min_rel_symbol_weight_s=params['min_rel_symbol_weight_s'],
            max_rel_symbol_weight_s=params['max_rel_symbol_weight_s'],
            n_symbols=channel_model.n_symbols,
            min_code_length=params['min_code_length'],
            max_code_length=params['max_code_length'],
            min_n_types=params['min_n_types'],
            max_n_types=params['max_n_types'],
            min_source_nonuniformity=params['min_source_nonuniformity'],
            max_source_nonuniformity=params['max_source_nonuniformity'])

        # generate problem spec
        problem_spec = experiments.SingleEntityCodingProblemSpecification(
            experiment_spec=experiment_spec,
            channel_spec=channel_spec)

        # sample a batch of codebooks
        codebook_generator_output_dict = codebook_generator.forward(
            code_length=experiment_spec.code_length,
            n_types=experiment_spec.n_types,
            batch_size=params['batch_size'],
            pi_t=experiment_spec.pi_t,
            nu_tj=None,
            disable_two_body=params['disable_two_body'],
            two_body_max_n_interactions=params['two_body_max_n_interactions'],
            two_body_dropout_policy=params['two_body_dropout_policy'],
            top_proposals_per_column=params['top_proposals_per_column'],
            random_proposals_per_column=params['random_proposals_per_column'],
            enable_hardcore_two_body_potential=params['enable_hardcore_two_body_potential'],
            action_policy=params['action_policy'],
            action_policy_epsilon=params['action_policy_epsilon'],
            top_k_noise_std=params['top_k_noise_std'])

        # get the decoding confusion matrix
        decoder_output_dict = channel_model.get_weighted_confusion_matrix(
            codebook_btls=codebook_generator_output_dict['codebook_btls'].to(device_eval),
            pi_bt=to_torch(experiment_spec.pi_t, device=device_eval, dtype=dtype).expand(
                [params['batch_size'], experiment_spec.n_types]),
            decoder_type=params['decoder_type'],
            **params['decoder_kwargs'])

        # get metrics dict
        metrics_dict = metric_utils.get_metrics_dict_from_decoder_output_dict(
            decoder_output_dict=decoder_output_dict,
            metrics_dict_type=params['metrics_dict_type'],
            **params['metrics_kwargs'])

        # reduce metrics dict to optimality
        optimality_b = metric_utils.get_optimality_from_metrics_dict(
            metrics_dict=metrics_dict,
            optimality_type=params['optimality_type'])

        # baseline: batch statistic of the current experiment, tracked with an
        # exponential moving average
        optimality_batch_baseline = baseline_reduction(optimality_b).item()
        baseline_running_average = (
            params['baseline_ma_beta'] * baseline_running_average +
            (1. - params['baseline_ma_beta']) * optimality_batch_baseline)

        # The moving average is updated once per *experiment*, so the bias
        # correction must use the number of updates made so far rather than
        # the number of optimizer iterations (they coincide only when
        # experiments_per_iter == 1).
        n_baseline_updates = i_iter * params['experiments_per_iter'] + i_expmt + 1
        baseline_running_average_unbiased = baseline_running_average / (
            1. - params['baseline_ma_beta'] ** n_baseline_updates)

        if params['use_running_average_baseline']:
            baseline = baseline_running_average_unbiased
        else:
            baseline = optimality_batch_baseline
        optimality_sub_baseline_b = optimality_b - baseline

        # REINFORCE: loss = -sum_b log_prob_b * (optimality_b - baseline)
        loss = -torch.dot(
            codebook_generator_output_dict['log_prob_b'],
            optimality_sub_baseline_b.to(device_gen))

        # gradients add up across the experiments of this iteration
        loss.backward()

        # bookkeeping
        optimality_hist.append(optimality_batch_baseline)
        baseline_hist.append(baseline_running_average_unbiased)

        if i_iter % log_frequency == 0:

            # ratio of the first to the last source prior entry
            source_nonuniformity = (experiment_spec.pi_t[0] / experiment_spec.pi_t[-1]).item()
            log_baseline = log_optimality_scale * baseline_running_average_unbiased

            log_string = ''.join([
                str(i_iter).ljust(small_col),
                f"{experiment_spec.code_length}".ljust(small_col),
                f"{experiment_spec.n_types}".ljust(small_col),
                f"{source_nonuniformity:.{sig_digits}f}".ljust(small_col),
                _format_mean_range(
                    optimality_b, log_optimality_scale, sig_digits).ljust(large_col),
                _format_mean_range(
                    1. - optimality_b, log_optimality_scale, sig_digits).ljust(large_col),
                f"{log_baseline:.{sig_digits}f}".ljust(small_col)])

            log_info(log_string)

    # gradient update
    optim.step()

    i_iter += 1

    # hand cached, unused CUDA memory back to the driver after every iteration
    torch.cuda.empty_cache()

## Visualize

In [None]:
# batch baseline (mean or median optimality) recorded for every training experiment
plt.plot(optimality_hist)

In [None]:
# bias-corrected running average of the batch baseline, one entry per training experiment
plt.plot(baseline_hist)
plt.ylim((0, 1.0))

In [None]:
# display the prefactor provider's `beta_constant_m` attribute as the cell output
two_body_potential_prefactor_provider.beta_constant_m

In [None]:
# display the prefactor provider's `beta_linear_m` attribute as the cell output
two_body_potential_prefactor_provider.beta_linear_m

In [None]:
# display the prefactor provider's `beta_bilinear_m` attribute as the cell output
two_body_potential_prefactor_provider.beta_bilinear_m

In [None]:
# display the two-body potential's `gamma_m` attribute as the cell output
two_body_potential.gamma_m

In [None]:
# display the two-body potential's `gamma_ms` attribute as the cell output
two_body_potential.gamma_ms

In [None]:
# display the two-body potential's `gamma_mss` attribute as the cell output
# (the commented-out override cell sets `gamma_mss_unconstrained` instead;
# presumably `gamma_mss` is derived from it -- TODO confirm)
two_body_potential.gamma_mss

In [None]:
# checkpoint the generator weights and the optimizer state into the working directory
torch.save(codebook_generator.state_dict(), './codebook_generator_state__multi__16.pt')
torch.save(optim.state_dict(), './optim__multi__16.pt')

In [None]:
# codebook_generator.load_state_dict(
#     torch.load('./codebook_generator_state__greedy__140__MERFISH__full_interaction__test.pt'),
#     strict=False)

In [None]:
# Evaluate the trained generator on one fixed problem (code length 16,
# 140 types, source nonuniformity 1000) using the greedy action policy with
# exploration noise switched off. No gradients are recorded.
with torch.no_grad():
    
    # generate experiment
    experiment_spec = experiments.generate_experiment_spec(
        name_prefix=params['experiment_prefix'],
        min_rel_symbol_weight_s=params['min_rel_symbol_weight_s'],
        max_rel_symbol_weight_s=params['max_rel_symbol_weight_s'],
        n_symbols=channel_model.n_symbols,
        min_code_length=16,
        max_code_length=16,
        min_n_types=140,
        max_n_types=140,
        min_source_nonuniformity=1000.,
        max_source_nonuniformity=1000.)

    # generate problem spec
    problem_spec = experiments.SingleEntityCodingProblemSpecification(
        experiment_spec=experiment_spec,
        channel_spec=channel_spec)

    # sample a batch of codebooks (a single codebook here: batch_size=1)
    codebook_generator_output_dict = codebook_generator.forward(
        code_length=experiment_spec.code_length,
        n_types=experiment_spec.n_types,
        batch_size=1,
        pi_t=experiment_spec.pi_t,
        nu_tj=None,
        disable_two_body=params['disable_two_body'],   # NOTE
        two_body_max_n_interactions=256,
        two_body_dropout_policy='random',
        top_proposals_per_column=2048,
        random_proposals_per_column=2048,
        enable_hardcore_two_body_potential=True,
        action_policy='greedy',
        action_policy_epsilon=0.,
        top_k_noise_std=0.)

    # get the decoding confusion matrix
    decoder_output_dict = channel_model.get_weighted_confusion_matrix(
        codebook_btls=codebook_generator_output_dict['codebook_btls'].to(device_eval),
        pi_bt=to_torch(experiment_spec.pi_t, device=device_eval, dtype=dtype).expand(
            [1, experiment_spec.n_types]),
        decoder_type=params['decoder_type'],
        **params['decoder_kwargs'])

    # get metrics dict
    metrics_dict = metric_utils.get_metrics_dict_from_decoder_output_dict(
        decoder_output_dict=decoder_output_dict,
        metrics_dict_type=params['metrics_dict_type'],
        **params['metrics_kwargs'])

    # reduce metrics dict to optimality
    optimality_b = metric_utils.get_optimality_from_metrics_dict(
        metrics_dict=metrics_dict,
        optimality_type=params['optimality_type'])

# report the best codebook in the batch as 1 - optimality
# (with batch_size=1 the argmax is necessarily index 0)
best_idx = torch.argmax(optimality_b)
print(f'FDR: {1. - optimality_b[best_idx]:.3f}')

In [None]:
# Three-panel figure for the codebook produced by the evaluation cell above:
# the binary codebook, per-type Hamming weight against the one-body potential
# properties, and per-type FDR.
codebook_btls = codebook_generator_output_dict['codebook_btls']
idx = 0

# get binary codebook: channel 1 of the last (one-hot symbol) axis for batch entry `idx`
codebook_tl = codebook_btls[idx, :, :, 1].detach().cpu().numpy()

# sort codebook: reorder the columns (code positions) lexicographically, in
# descending order, with the first row as the most significant key
codebook_tl = codebook_tl[:, np.lexsort(codebook_tl[::-1, :])[::-1]]

# aux: number of ones per row (unaffected by the column reordering above)
hamming_weight_t = codebook_tl.sum(-1)

# one-body potential, queried once per type: the "batch" axis of this call
# enumerates the types of the experiment
with torch.no_grad():
    one_body_props = one_body_potential.get_one_body_potential_props(
        n_types_b=torch.tensor(
            experiment_spec.n_types, device=device_gen, dtype=dtype).expand([experiment_spec.n_types]),
        code_length_b=torch.tensor(
            experiment_spec.code_length, device=device_gen, dtype=dtype).expand([experiment_spec.n_types]),
        type_rank_b=torch.arange(
            experiment_spec.n_types, device=device_gen, dtype=dtype) / experiment_spec.n_types,
        pi_b=torch.tensor(
            experiment_spec.pi_t, device=device_gen, dtype=dtype),
        pi_cdf_b=torch.cumsum(torch.tensor(
            experiment_spec.pi_t, device=device_gen, dtype=dtype), -1))

# make figure
fig, axs = plt.subplots(nrows=3, figsize=(14, 7))

# panel 1: codebook image (types along x, code positions along y)
ax = axs[0]
ax.imshow(codebook_tl.T,  cmap=plt.cm.gray)
ax.set_xlabel('Rank (descending source prior)')
ax.set_ylabel('Code')
ax.set_xticks([])
ax.set_xlim((-0.5, experiment_spec.n_types - 0.5))
    
# panel 2: realized Hamming weight vs. code_length * symbol-1 weight from the
# one-body potential
ax = axs[1]
ax.plot(hamming_weight_t)
ax.plot(experiment_spec.code_length * one_body_props['symbol_weights_bs'][:, 1].cpu().numpy())
ax.set_xlabel('Rank (descending source prior)')
ax.set_ylabel('Hamming Weight')
ax.set_ylim((0, experiment_spec.code_length))
ax.set_xticks([])
ax.set_xlim((-0.5, experiment_spec.n_types - 0.5))

# panel 2, secondary y-axis: one-body potential strength per type
ax2 = ax.twinx()
ax2.plot(one_body_props['potential_strength_b'].cpu().numpy(), color='red')
ax2.set_xlabel('Rank (descending source prior)')
ax2.set_ylabel('Strength', color='red')
ax2.set_xlim((-0.5, experiment_spec.n_types - 0.5))

# panel 3: per-type FDR of batch entry `idx`
ax = axs[2]
ax.bar(np.arange(experiment_spec.n_types), metrics_dict['fdr_bt'][idx, :].cpu().numpy())
ax.set_xlabel('Rank (descending source prior)')
ax.set_ylabel('FDR')
ax.set_xlim((-0.5, experiment_spec.n_types - 0.5))

fig.tight_layout()

## Compare to SA

In [None]:
import pandas as pd
import pickle

# load SA data
# NOTE: absolute, machine-specific path; the augmented table is written back
# to the same directory by a later cell.
root_path = '/home/jupyter/mb-ml-dev-disk/MightyCodes/notebooks/terra_analysis'
sa_bac_df = pd.read_csv(os.path.join(root_path, 'sa_bac_df.tsv'), delimiter='\t')

In [None]:
# For every row of the SA table, build the matching problem, generate one
# codebook greedily with the trained generator, and record 1 - optimality.
lowest_energy_var = []

for i, row in sa_bac_df.iterrows():
    print(f'{i} ...')

    # `iterrows` returns each row as a single-dtype Series, so integer columns
    # come back as floats whenever the frame also contains float columns.
    # Cast the integer-valued settings back to `int` before handing them to
    # code that expects counts / sizes.
    row_code_length = int(row['code_length'])
    row_n_types = int(row['n_types'])

    with torch.no_grad():

        # generate experiment
        # NOTE(review): every other call in this notebook passes
        # `min_rel_symbol_weight_s` / `max_rel_symbol_weight_s` (relative
        # weights); here absolute Hamming weights are passed under
        # `min_symbol_weight_s` / `max_symbol_weight_s`. Confirm that
        # `generate_experiment_spec` accepts these keywords.
        experiment_spec = experiments.generate_experiment_spec(
            name_prefix='test',
            min_symbol_weight_s=[row['min_hamming_weight'], row['min_hamming_weight']],
            max_symbol_weight_s=[row['max_hamming_weight'], row['max_hamming_weight']],
            n_symbols=2,
            min_code_length=row_code_length,
            max_code_length=row_code_length,
            min_n_types=row_n_types,
            max_n_types=row_n_types,
            min_source_nonuniformity=row['source_nonuniformity'],
            max_source_nonuniformity=row['source_nonuniformity'])

        # generate problem spec
        problem_spec = experiments.SingleEntityCodingProblemSpecification(
            experiment_spec=experiment_spec,
            channel_spec=channel_spec)

        # sample a batch of codebooks (a single greedy codebook per row)
        # NOTE(review): the two-body term is enabled here, whereas training
        # used params['disable_two_body'] -- confirm this is intended.
        codebook_generator_output_dict = codebook_generator.forward(
            code_length=experiment_spec.code_length,
            n_types=experiment_spec.n_types,
            batch_size=1,
            pi_t=experiment_spec.pi_t,
            nu_tj=None,
            disable_two_body=False,
            two_body_max_n_interactions=row_n_types,
            two_body_dropout_policy='random',
            top_proposals_per_column=1024,
            random_proposals_per_column=1024,
            enable_hardcore_two_body_potential=True,
            action_policy='greedy',
            action_policy_epsilon=0.,
            top_k_noise_std=0.)

        # get the decoding confusion matrix
        decoder_output_dict = channel_model.get_weighted_confusion_matrix(
            codebook_btls=codebook_generator_output_dict['codebook_btls'].to(device_eval),
            pi_bt=to_torch(experiment_spec.pi_t, device=device_eval, dtype=dtype).expand(
                [1, experiment_spec.n_types]),
            decoder_type=params['decoder_type'],
            **params['decoder_kwargs'])

        # get metrics dict
        metrics_dict = metric_utils.get_metrics_dict_from_decoder_output_dict(
            decoder_output_dict=decoder_output_dict,
            metrics_dict_type=params['metrics_dict_type'],
            **params['metrics_kwargs'])

        # reduce metrics dict to optimality
        optimality_b = metric_utils.get_optimality_from_metrics_dict(
            metrics_dict=metrics_dict,
            optimality_type=params['optimality_type'])

        # add to list (batch of one, so `.item()` is well defined)
        optimality_compl = 1. - optimality_b.item()
        lowest_energy_var.append(optimality_compl)

In [None]:
# attach the per-row results as a new column (requires one entry per table
# row, i.e. the loop above must have run to completion) and write the
# augmented table next to the input file
sa_bac_df['lowest_energy_var'] = np.asarray(lowest_energy_var)
sa_bac_df.to_csv(os.path.join(root_path, 'sa_bac_w_var_df.tsv'), sep='\t', index=False)

## Optimal codes

In [None]:
# NOTE(review): `estimate_codebook_auc_f1_reject`, `model`, `output_dict`,
# `pi_t`, `n_types` and `max_rej_ratio` are not defined or imported anywhere
# in the visible notebook; this cell appears to rely on an earlier session or
# an older API -- TODO confirm before running.
metrics_dict = estimate_codebook_auc_f1_reject(
    model=model,
    c_btls=output_dict['c_btls'].to(device_eval),
    pi_bt=pi_t.expand((1, n_types)).to(device_eval),
    n_samples_per_type=100_000,
    n_map_reject_thresholds=50,
    delta_q_max=1e-4,
    max_rej_ratio=max_rej_ratio,
    max_n_samples_per_type_per_sampling_round=100,
    return_confusion_matrix=True,
    device=device_eval,
    dtype=dtype)

In [None]:
# F1 score vs. rejection rate for batch entry 0, saved as a PNG.
# NOTE(review): `viz_metrics_dict` is not defined or imported in the visible
# notebook, and the './figs' directory must already exist.
viz_metrics_dict(metrics_dict, 0, x_key='clamped_rej', y_key='clamped_f_1', x_label='Rejection Rate', y_label='$F_1$ Score')
plt.xlim((0, 0.25))
plt.ylim((0.85, 1.01))
plt.tight_layout()
plt.savefig('./figs/pi_nu_300_700__opt_f1_score.png', dpi=200)

In [None]:
# Codebook and FDR figures for batch entry 0; the first and third returned
# figures are relabelled and saved, the second one is left untouched.
# NOTE(review): `viz_binary_codebook_and_metrics` and `output_dict` are not
# defined in the visible notebook.
fig_1, fig_2, fig_3 = viz_binary_codebook_and_metrics(metrics_dict, 0, output_dict['c_btls'], metric_key='fdr', figsize=(12, 3))

fig_1.gca().set_ylabel('Code')
fig_1.gca().set_xlabel('Rank')
fig_1.gca().set_yticks([])
fig_1.tight_layout()
fig_1.savefig('./figs/pi_nu_300_700__opt_code.png', dpi=300)

fig_3.gca().set_ylabel('FDR')
fig_3.gca().set_ylim((0, 0.1))
fig_3.tight_layout()
fig_3.savefig('./figs/pi_nu_300_700__opt_fdr.png')

In [None]:
# display 'normalized_auc_f_1_rej_bt' averaged over its last axis
metrics_dict['normalized_auc_f_1_rej_bt'].mean(-1)

In [None]:
# image of the running (cumulative) sum, taken along the first axis, of the
# symbol-1 channel of batch entry 0
# NOTE(review): `output_dict` is not defined in the visible notebook.
c_tl = output_dict['c_btls'][0, :, :, 1]
z_tl = torch.cumsum(c_tl, dim=0)
plt.figure(figsize=(16, 5))
plt.imshow(to_np(z_tl).T)

## MHD4 Codes

In [None]:
# load MHD4 codebook
# NOTE: absolute, machine-specific data path.
MERFISH_coding_data_root = '/home/jupyter/mb-ml-data/MERFISH-coding'
mhd4_codebook_krc = np.load(
    os.path.join(MERFISH_coding_data_root, 'MERFISH_2016_U2OS_1_codebook_krc.npy'))
n_codes_mhd4 = mhd4_codebook_krc.shape[0]

# flatten all trailing axes into a single code axis: (n_codes, code_length)
mhd4_codebook_tl = mhd4_codebook_krc.reshape(n_codes_mhd4, -1)

# one-hot representation: use each bit as a row index into the 2x2 identity.
# The explicit integer cast keeps this a fancy-index lookup even when the
# stored array is boolean or floating point (a boolean array would otherwise
# be interpreted as a mask, a float array is rejected as an index).
mhd4_codebook_tls = np.eye(2)[
    mhd4_codebook_tl.flatten().astype(np.int64), :].reshape(mhd4_codebook_tl.shape + (2,))

# shuffle the order of the codes and add a leading batch axis of size one
mhd4_c_btls = to_torch(
    mhd4_codebook_tls[np.random.permutation(n_codes_mhd4), :, :][None, ...],
    device=device_eval, dtype=dtype)

# NOTE(review): `estimate_codebook_auc_f1_reject`, `model`, `pi_t`, `n_types`,
# `delta_q_max` and `max_rej_ratio` are not defined in the visible notebook.
# `pi_bt` is moved to `device_eval` so that it lives on the same device as
# `mhd4_c_btls`, matching the call made for the optimized codes.
mhd4_metrics_dict = estimate_codebook_auc_f1_reject(
    model=model,
    c_btls=mhd4_c_btls,
    pi_bt=pi_t.expand((1, n_types)).to(device_eval),
    n_samples_per_type=50_000,
    n_map_reject_thresholds=50,
    delta_q_max=delta_q_max,
    max_rej_ratio=max_rej_ratio,
    max_n_samples_per_type_per_sampling_round=100,
    return_confusion_matrix=False,
    device=device_eval,
    dtype=dtype)

In [None]:
# F1 score vs. rejection rate for the MHD4 codebook, using the same axis limits
# as the corresponding plot for the optimized codes.
# NOTE(review): `viz_metrics_dict` is not defined or imported in the visible
# notebook, and the './figs' directory must already exist.
viz_metrics_dict(mhd4_metrics_dict, 0, x_key='clamped_rej', y_key='clamped_f_1', x_label='Rejection Rate', y_label='$F_1$ Score')
plt.xlim((0, 0.25))
plt.ylim((0.85, 1.01))
plt.tight_layout()
plt.savefig('./figs/pi_nu_300__mhd4_f1_score.png', dpi=200)

In [None]:
# display 'normalized_auc_f_1_rej_bt' of the MHD4 codebook averaged over its last axis
mhd4_metrics_dict['normalized_auc_f_1_rej_bt'].mean(-1)

In [None]:
# Codebook and FDR figures for the MHD4 codebook; the first and third returned
# figures are relabelled and saved, the second one is left untouched.
# NOTE(review): `viz_binary_codebook_and_metrics` is not defined in the
# visible notebook.
fig_1, fig_2, fig_3 = viz_binary_codebook_and_metrics(mhd4_metrics_dict, 0, mhd4_c_btls, metric_key='fdr', figsize=(12, 3))

fig_1.gca().set_ylabel('Code')
fig_1.gca().set_xlabel('Rank')
fig_1.gca().set_yticks([])
fig_1.tight_layout()
fig_1.savefig('./figs/pi_nu_300__mhd4_code.png')

fig_3.gca().set_ylabel('FDR')
fig_3.gca().set_ylim((0, 0.1))
fig_3.tight_layout()
fig_3.savefig('./figs/pi_nu_300__mhd4_fdr.png')

In [None]:
# image of the running (cumulative) sum, taken along the first axis, of the
# symbol-1 channel of the MHD4 codebook
c_tl = mhd4_c_btls[0, :, :, 1]
z_tl = torch.cumsum(c_tl, dim=0)
plt.figure(figsize=(16, 5))
plt.imshow(to_np(z_tl).T)

## Extended plots

In [None]:
# Bar chart of the source prior on a log y-axis, with bar colors blended
# linearly from red (first bar) to blue (last bar).
# NOTE(review): `n_types` and `pi_t` are not defined in the visible notebook;
# the color blend divides by `n_types - 1`, so it needs at least two types.
plt.figure(figsize=(12, 3))
barlist = plt.bar(np.arange(n_types) + 1, to_np(pi_t), width=0.5)
red = np.asarray([1., 0., 0.])
blue = np.asarray([0., 0., 1.])
for i_t, bar in enumerate(barlist):
    color = red * (1 - i_t / (n_types - 1)) + blue * i_t / (n_types - 1)
    bar.set_color(color)
plt.yscale('log')
plt.xlabel('Rank')
plt.ylabel('Source Prior')
plt.xlim((1, n_types + 1))
plt.tight_layout()
plt.savefig('./figs/pi_nu_300_700_prior.png', dpi=200)

In [None]:
# Confusion-matrix slice at first-axis index 0 and second-axis index 10
# (presumably batch entry and rejection-threshold index -- TODO confirm),
# normalized so that every row sums to one, shown in grayscale with the
# color scale capped at 0.5.
w = metrics_dict['weighted_confusion_matrix_bqtu'][0, 10, :, :]
w = w / w.sum(-1, keepdim=True)
plt.imshow(to_np(w), vmax=0.5, cmap=plt.cm.gray)
plt.xticks([])
plt.yticks([])