In [1]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from main import read_in_blicks, BOUNDARY, eval_auc
from run_english_batch import *
import scorers
import datasets
import informants
import learners
from tqdm import tqdm

from util import entropy, kl_bern

In [3]:
import pandas as pd

In [4]:
# Feature set name; passed to load_informant_scorer() and to MeanFieldScorer in initialize_hyp().
feature_type = 'english'
# Seed that run() hands to set_seeds() before sampling and before training.
SEED=1

In [5]:
# Build the informant and scorer for `feature_type`; later cells read both as notebook globals.
informant, mf_scorer = load_informant_scorer(feature_type)

Loading lexicon from:	data/hw/english_lexicon.txt
Loading lexicon with min_length=2, max_length=5...
Reading phoneme features from: data/hw/english_features.txt
# features:  54872
feature type:  english
Reading phoneme features from: data/hw/english_features.txt
Loading ngram features from: data/hw/english_feature_weights.txt


In [6]:
# Directory that holds the random-string training CSVs.
data_dir = 'data/BabbleRandomStringsEnglish'
# Eager loading of the training sets is disabled; run() loads each train file on demand
# and keeps it in TRAIN_CACHES.
# train_random_strings = load_train_dataset(f'{data_dir}/RandomStringsSubsampledBalanced.csv')
# train_random_wellformed = load_train_dataset(f'{data_dir}/RandomWellFormedSyllablesSubsampledBalanced.csv')
# train_babbled = load_train_dataset(f'data/MakingOverTrainSet/EnglishOverTrainingData.csv', informant, mf_scorer)


## Set random seeds

In [7]:
import random
import numpy as np


def set_seeds(seed, dataset):
    """Reseed every random number generator the notebook relies on.

    Applies `seed`, in this order, to Python's global `random` module, to
    NumPy's global generator, and to the generator exposed as `dataset.random`.
    """
    for reseed in (random.seed, np.random.seed, dataset.random.seed):
        reseed(seed)

## Load eval dataset

In [8]:
# Load the evaluation set once; get_auc() and run() read its 'encoded', 'label' and 'cost' columns.
eval_dataset = load_eval_dataset(informant, mf_scorer)
display(eval_dataset)

 80%|████████  | 9928/12390 [00:58<00:24, 102.38it/s]

## Initialize learner

## Main loop

In [15]:
def get_auc(scorer, eval_dataset, length_norm = False):
    """Score every eval item with `scorer` and return the AUC against the gold labels.

    `scorer.cost()` supplies the prediction for each entry of
    `eval_dataset['encoded']`; `length_norm` is forwarded to it unchanged.
    The costs and `eval_dataset['label']` are then handed to `eval_auc`.
    """
    predicted_costs = []
    for encoded_item in eval_dataset['encoded'].values:
        predicted_costs.append(scorer.cost(encoded_item, length_norm = length_norm))
    gold_labels = eval_dataset['label'].values
    return eval_auc(predicted_costs, gold_labels)

In [16]:
# Cache of loaded training datasets keyed by train file path; filled lazily by run().
TRAIN_CACHES = {}


In [17]:
import os
import wandb

def initialize_hyp(lla, prior, tol, max_updates, dataset, phoneme_feature_path):
    """Create a fresh `scorers.MeanFieldScorer` over `dataset`.

    Args:
        lla: passed as `log_log_alpha_ratio`.
        prior: passed as `prior_prob`.
        tol: passed as `tolerance`.
        max_updates: accepted for signature compatibility; not used here.
        dataset: first positional argument of the scorer.
        phoneme_feature_path: passed as `phoneme_feature_file`.

    The notebook-level `feature_type` is passed through as well.

    Returns:
        The newly constructed scorer.
    """
    print("Initializing learner...")
    # A different scorer class may be needed if the underlying linear model changes.
    scorer_options = dict(
        log_log_alpha_ratio=lla,
        prior_prob=prior,
        feature_type=feature_type,
        tolerance=tol,
        phoneme_feature_file=phoneme_feature_path,
    )
    return scorers.MeanFieldScorer(dataset, **scorer_options)

def run(lla, prior, max_updates, tol, train_file, eval_dataset, phoneme_feature_path,
        out_dir='big_batch',
        wandb_project='1114_big_batch',
        num_samples=None,
       ):
    """Train one scorer on `train_file`, evaluate it, and log everything to W&B.

    Args:
        lla: log-log-alpha ratio for the scorer.
        prior: prior probability for the scorer.
        max_updates: forwarded to `scorer.update()`.
        tol: tolerance for the scorer.
        train_file: path of the training CSV; loaded once and cached in TRAIN_CACHES.
        eval_dataset: dataset with 'encoded', 'label' and 'cost' columns used for AUC
            and for the oracle-cost histogram.
        phoneme_feature_path: phoneme feature file handed to the scorer.
        out_dir: base directory for outputs; created if missing.
        wandb_project: W&B project that receives the run.
        num_samples: if not None, train on a random sample of this many items
            (drawn after seeding with SEED).

    Side effects:
        Starts and finishes a W&B run, and writes `probs.npy` and `auc.txt` into
        `<out_dir>/lla=..._prior=..._max-updates=..._tol=..._num-samples=...`.
        Returns nothing.
    """
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and supports nested paths.
    os.makedirs(out_dir, exist_ok=True)

    config = {
        'log_log_alpha_ratio': lla,
        'prior_prob': prior,
        'max_updates': max_updates,
        'tolerance': tol,
        'train_file': train_file,
        'num_samples': num_samples,
    }

    wandb_run = wandb.init(config=config, project=wandb_project, reinit=True, entity='lm-informants')

    print(config)

    print('train caches:', TRAIN_CACHES)

    # Load each training file at most once per kernel session.
    if train_file in TRAIN_CACHES:
        train_dataset = TRAIN_CACHES[train_file]
    else:
        train_dataset = load_train_dataset(train_file, informant, mf_scorer)
        TRAIN_CACHES[train_file] = train_dataset

    print('train caches:', TRAIN_CACHES)

    sub_dir = os.path.join(out_dir, f'lla={lla}_prior={prior}_max-updates={max_updates}_tol={tol}_num-samples={num_samples}')
    os.makedirs(sub_dir, exist_ok=True)

    scorer = initialize_hyp(lla, prior, tol, max_updates, informant.dataset, phoneme_feature_path)

    print('avg prior:', scorer.probs.mean())

    ordered_feats = train_dataset['featurized'].values
    ordered_judgments = train_dataset['label'].values
    # The scorer is given judgments as +1 / -1 rather than truthy / falsy labels.
    ordered_judgments = [1 if j else -1 for j in ordered_judgments]

    if num_samples is not None:
        # TODO: setting seed of informant.dataset; is that what we want?
        # (I think it would only matter for learner.dataset, for getting a train candidate, which we're not using here)
        set_seeds(SEED, informant.dataset)
        ordered_feats, ordered_judgments = zip(*random.sample(list(zip(ordered_feats, ordered_judgments)), num_samples))

    print('# data:', len(ordered_feats))

    # Reseed so training starts from the same RNG state whether or not sampling happened.
    # TODO: setting seed of informant.dataset; is that what we want?
    # (I think it would only matter for learner.dataset, for getting a train candidate, which we're not using here)
    set_seeds(SEED, informant.dataset)

    # Log distribution over train scores
    table_data = [[s] for s in train_dataset['cost'].values]
    table = wandb.Table(data=table_data, columns=["oracle_costs"])
    wandb.log({'train_oracle_costs': wandb.plot.histogram(table, "oracle_costs",
          title="Train: Distribution of oracle costs", num_bins=20)})

    # Log distribution over eval scores
    table_data = [[s] for s in eval_dataset['cost'].values]
    print(table_data)
    table = wandb.Table(data=table_data, columns=["oracle_costs"])
    wandb.log({'eval_oracle_costs': wandb.plot.histogram(table, "oracle_costs",
          title="Eval: Distribution of oracle costs", num_bins=20)})

    scorer.update(
        ordered_feats, ordered_judgments,
        max_updates=max_updates,
        verbose=False)

    print("Getting auc...")
    auc = get_auc(scorer, eval_dataset)
    print("Done.")

    print("")
    print(f"auc: {auc}")

    print('avg posterior:', scorer.probs.mean())

    # Log distribution over learned thetas
    table = wandb.Table(data=[[s] for s in scorer.probs], columns=["prob"])
    wandb.log({'learned_probs': wandb.plot.histogram(table, "prob",
          title="Distribution of learned thetas", num_bins=20)})

    probs_file = os.path.join(sub_dir, 'probs.npy')
    # Save to probs.npy in the working directory first so it is registered with wandb under that name.
    np.save('probs.npy', scorer.probs,)
    wandb.save('probs.npy')
    print(f"Writing probs to: {probs_file}")
    # Move probs.npy to its final location.
    # NOTE(review): the file is renamed right after wandb.save(); confirm the upload does not
    # depend on 'probs.npy' still existing at its original path.
    os.rename('probs.npy', probs_file)

    auc_file_name = os.path.join(sub_dir, 'auc.txt')
    print(f'Writing auc to {auc_file_name}')
    # Context manager guarantees the file is flushed and closed.
    with open(auc_file_name, 'w') as auc_file:
        print(f'{auc}', file=auc_file)

    wandb.log({'auc': auc})

    print("================================")

    wandb_run.finish()

In [18]:
def get_probs_file(config, base_dir='big_batch'):
    """Return the path of the `probs.npy` file written for a run with `config`.

    The sub-directory name is built from the run's hyperparameters in the
    form `lla=..._prior=..._max-updates=..._tol=..._num-samples=...`
    and joined under `base_dir`.
    """
    name_parts = (
        ('lla', 'log_log_alpha_ratio'),
        ('prior', 'prior_prob'),
        ('max-updates', 'max_updates'),
        ('tol', 'tolerance'),
        ('num-samples', 'num_samples'),
    )
    sub_dir = '_'.join(f'{label}={config[key]}' for label, key in name_parts)
    return os.path.join(base_dir, sub_dir, 'probs.npy')

def load_probs(run, project_dir='lm-informants/1114_big_batch'):
    """Download a run's saved `probs.npy` from W&B and return it as a NumPy array.

    Args:
        run: W&B run object; its `config` determines the file path and its `id`
            identifies the run to restore from.
        project_dir: `entity/project` prefix of the run path.

    Returns:
        The array stored in the run's probs file.

    Raises:
        FileNotFoundError: if `wandb.restore` returns nothing for the file.
    """
    f = get_probs_file(run.config)
    print("Loading probs from file at: ", f)
    f_obj = wandb.restore(f, run_path=f'{project_dir}/{run.id}')
    if f_obj is None:
        raise FileNotFoundError(f'Could not restore {f} from run {run.id}')
    # wandb.restore hands back an open file object for the downloaded copy; a .npy file
    # has a header plus a binary payload, so it must be parsed with np.load on the path
    # rather than by feeding the raw contents to np.fromstring.
    local_path = f_obj.name
    f_obj.close()
    return np.load(local_path)

In [19]:
import wandb
import pandas as pd
from tqdm import tqdm

def get_wandb_runs():
    """Collect every finished run of the `lm-informants/1114_big_batch` project.

    For each finished run, the record holds the run id, its metric history, its
    config (both flattened into columns and as a `config` dict), its summary
    values, the learned `probs` array downloaded from the run's files, and
    `probs_mean`.

    Returns:
        A DataFrame with one row per finished run that logged an `auc`;
        remaining missing values are replaced by the string "None".
        An empty DataFrame is returned when no run qualifies.

    Raises:
        AssertionError: if a finished run does not have exactly one `.npy` file.
    """
    # Uses the stored W&B credentials (or prompts for an API key).
    wandb.login()

    # Get all runs from the project
    api = wandb.Api()
    runs = api.runs("lm-informants/1114_big_batch")

    data = []
    # Downloaded run files are placed under this directory.
    root = 'temp_files'

    for run in tqdm(runs):
        run_id = run.id

        # Check the status first so no history/summary is fetched for runs we discard.
        if run.state != "finished":
            print(f"Skipping run {run_id} because status is {run.state}")
            continue

        metrics = run.history()
        config = run.config

        d = {
            "run_id": run_id,
            "metrics": metrics,
        }
        d.update(config)
        d['config'] = config
        d.update(run.summary._json_dict)

        # Download the run's .npy file(s) and load the learned probabilities.
        num_files_found = 0
        for f in run.files():
            if f.name.endswith('.npy'):
                f.download(root=root, replace=True)
                f_path = os.path.join(root, f.name)
                print(f_path)
                num_files_found += 1

                probs = np.load(f_path)
                d.update({'probs': probs, 'probs_mean': probs.mean()})

        assert num_files_found == 1, f'Found {num_files_found} npy files for run {run_id}'
        # Record the run only once it is known to have exactly one probs file.
        data.append(d)

    # Convert the list of dictionaries to a DataFrame
    df = pd.DataFrame(data)
    if df.empty:
        # Nothing collected: there is no 'auc' column to filter on.
        return df

    df = df.dropna(subset=['auc'], axis=0)
    # Missing config values (e.g. max_updates, num_samples) become the string "None".
    df = df.fillna("None")

    return df

In [20]:
!rm -r temp_files

In [21]:
# Pull all finished runs (with their learned probs) from W&B; the later analysis cells read `df`.
df = get_wandb_runs()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33malexisjihyeross[0m. Use [1m`wandb login --relogin`[0m to force relogin
  3%|▎         | 1/33 [00:00<00:30,  1.05it/s]

temp_files/big_batch/lla=5.41687946870128_prior=0.00240504883318384_max-updates=None_tol=1.953125e-06_num-samples=None/probs.npy


  6%|▌         | 2/33 [00:01<00:29,  1.04it/s]

temp_files/big_batch/lla=5.41687946870128_prior=0.00240504883318384_max-updates=1_tol=1.953125e-06_num-samples=None/probs.npy


  9%|▉         | 3/33 [00:02<00:28,  1.05it/s]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=None/probs.npy


 12%|█▏        | 4/33 [00:04<00:30,  1.05s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=1_tol=1.953125e-06_num-samples=None/probs.npy


 15%|█▌        | 5/33 [00:05<00:30,  1.08s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=7000/probs.npy


 24%|██▍       | 8/33 [00:06<00:14,  1.74it/s]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=1_tol=1.953125e-06_num-samples=7000/probs.npy
Skipping run ra4xck2j because status is failed
Skipping run b8l4fqzq because status is failed


 27%|██▋       | 9/33 [00:07<00:16,  1.47it/s]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=1_tol=1.953125e-06_num-samples=7000/probs.npy


 30%|███       | 10/33 [00:08<00:16,  1.37it/s]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=1000/probs.npy


 33%|███▎      | 11/33 [00:09<00:17,  1.26it/s]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=1000/probs.npy


 36%|███▋      | 12/33 [00:10<00:18,  1.13it/s]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=1000/probs.npy


 39%|███▉      | 13/33 [00:11<00:18,  1.09it/s]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=5000/probs.npy


 42%|████▏     | 14/33 [00:12<00:17,  1.08it/s]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=5000/probs.npy


 45%|████▌     | 15/33 [00:13<00:17,  1.06it/s]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=5000/probs.npy


 48%|████▊     | 16/33 [00:14<00:17,  1.02s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=None/probs.npy


 52%|█████▏    | 17/33 [00:15<00:16,  1.05s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=None/probs.npy


 55%|█████▍    | 18/33 [00:16<00:16,  1.09s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=None/probs.npy


 58%|█████▊    | 19/33 [00:17<00:15,  1.10s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=10/probs.npy


 61%|██████    | 20/33 [00:19<00:14,  1.13s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=10/probs.npy


 64%|██████▎   | 21/33 [00:20<00:13,  1.09s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=None_tol=1.953125e-06_num-samples=10/probs.npy


 67%|██████▋   | 22/33 [00:21<00:11,  1.07s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=1_tol=1.953125e-06_num-samples=1000/probs.npy


 70%|██████▉   | 23/33 [00:22<00:10,  1.03s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=1_tol=1.953125e-06_num-samples=1000/probs.npy


 73%|███████▎  | 24/33 [00:23<00:09,  1.03s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=1_tol=1.953125e-06_num-samples=1000/probs.npy


 76%|███████▌  | 25/33 [00:24<00:08,  1.03s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=1_tol=1.953125e-06_num-samples=5000/probs.npy


 79%|███████▉  | 26/33 [00:25<00:07,  1.06s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=1_tol=1.953125e-06_num-samples=5000/probs.npy


 82%|████████▏ | 27/33 [00:26<00:06,  1.01s/it]

temp_files/big_batch/lla=0.522731931474557_prior=0.00138533389897108_max-updates=1_tol=1.953125e-06_num-samples=5000/probs.npy


 82%|████████▏ | 27/33 [00:27<00:06,  1.00s/it]


KeyboardInterrupt: 

In [None]:
    
# Short experiment name for every dataset path that can appear as a run's `train_file`.
file_map = dict([
    ('WordsAndScoresFixed_newest.csv', 'eval'),
    ('data/BabbleRandomStringsEnglish/RandomStringsSubsampledBalanced.csv', 'random_strings'),
    ('data/MakingOverTrainSet/EnglishOverTrainingData.csv', 'hw_babble'),
    ('data/BabbleRandomStringsEnglish/RandomWellFormedSyllablesSubsampledBalanced.csv', 'random_wellformed'),
])

# Reverse lookup: short name -> dataset path.
inverse_file_map = {}
for dataset_path, short_name in file_map.items():
    inverse_file_map[short_name] = dataset_path

In [None]:
# Add a short dataset name per run; raises KeyError if a run's train_file is missing from file_map.
df['train_file_short'] = df.apply(lambda row: file_map[row['train_file']], axis=1)

In [None]:
df

Unnamed: 0,run_id,metrics,tolerance,prior_prob,train_file,max_updates,num_samples,log_log_alpha_ratio,config,eval_oracle_costs_table,train_oracle_costs_table,auc,_step,_wandb,_runtime,_timestamp,learned_probs_table,probs,probs_mean,train_file_short
0,kpqlkcdq,_timestamp ...,2e-06,0.002405,WordsAndScoresFixed_newest.csv,,,5.416879,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'path': 'media/table/eval_oracle_costs_table_...,"{'size': 104670, '_type': 'table-file', 'ncols...",0.559767,3,{'runtime': 22119},22132.649222,1700268000.0,"{'_type': 'table-file', 'ncols': 1, 'nrows': 5...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.335859,eval
1,esrhwmyt,_runtime _timestamp \ 0 185.9188...,2e-06,0.002405,WordsAndScoresFixed_newest.csv,1.0,,5.416879,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,"{'nrows': 11336, 'sha256': 'b0146e055c677beef0...",0.559768,3,{'runtime': 13411},13412.489867,1700246000.0,"{'nrows': 54872, 'sha256': '33435f47aba1533e68...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.335898,eval
2,260x2m2l,_step _runtime _timestamp \ 0 ...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,,,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,{'path': 'media/table/train_oracle_costs_table...,0.598111,3,{'runtime': 27886},27900.865152,1700189000.0,"{'_type': 'table-file', 'ncols': 1, 'nrows': 5...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.230144,eval
3,tku5l5t1,_runtime _timestamp \ 0 184.0696...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,1.0,,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'size': 104670, '_type': 'table-file', 'ncols...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,0.598111,3,{'runtime': 13485},13485.906964,1700161000.0,"{'ncols': 1, 'nrows': 54872, 'sha256': '37d8e3...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.230115,eval
4,4r66aqrk,_step _runtime _timestamp \ 0 ...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,,7000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'ncols': 1, 'nrows': 11336, 'sha256': 'b0146e...","{'ncols': 1, 'nrows': 11336, 'sha256': 'b0146e...",0.606149,3,{'runtime': 15585},15591.264502,1700103000.0,{'path': 'media/table/learned_probs_table_2_13...,"[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.195201,eval
5,4oh9gzmw,eval_oracle_costs...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,1.0,7000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'nrows': 11336, 'sha256': 'b0146e055c677beef0...",{'artifact_path': 'wandb-client-artifact://lcc...,0.606149,3,{'runtime': 4467},4467.807258,1700087000.0,{'sha256': '1f8c27fdca75ec35faab77986d5808af29...,"[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.195158,eval
6,roaovkp2,auc _step _runtime _timestamp ...,2e-06,0.001385,data/BabbleRandomStringsEnglish/RandomStringsS...,1.0,7000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'artifact_path': 'wandb-client-artifact://fyw...,"{'_type': 'table-file', 'ncols': 1, 'nrows': 8...",0.499816,3,{'runtime': 2911},2912.432593,1700067000.0,"{'ncols': 1, 'nrows': 54872, 'sha256': 'c71e69...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.500644,random_strings
7,kqjx5wf4,auc _step _runtime _timestamp ...,2e-06,0.001385,data/MakingOverTrainSet/EnglishOverTrainingDat...,,1000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'_type': 'table-file', 'ncols': 1, 'nrows': 1...","{'size': 60426, '_type': 'table-file', 'ncols'...",0.692479,3,{'runtime': 714},718.528733,1700055000.0,{'sha256': '0503f33364d06d9a1da7fed1e1ebd8a611...,"[0.0885239943723408, 0.054540839935633696, 0.0...",0.004306,hw_babble
8,bo8wlszs,_timestamp ...,2e-06,0.001385,data/BabbleRandomStringsEnglish/RandomWellForm...,,1000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,{'_latest_artifact_path': 'wandb-client-artifa...,0.500194,3,{'runtime': 510},514.366387,1700055000.0,"{'size': 203550, '_type': 'table-file', 'ncols...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.209299,random_wellformed
9,105c3niz,auc _step _runtime _timestamp ...,2e-06,0.001385,data/BabbleRandomStringsEnglish/RandomStringsS...,,1000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'artifact_path': 'wandb-client-artifact://fv9...,{'artifact_path': 'wandb-client-artifact://3sb...,0.499887,3,{'runtime': 398},402.580664,1700054000.0,"{'nrows': 54872, 'sha256': 'f5d51a30a4c3a9a77d...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.230404,random_strings


In [None]:
# Feature index -> decoded feature (via mf_scorer.feature_vocab.decode); the plotting
# helpers use it to look up the feature behind each position of a `probs` array.
indices_to_feats = {feat_idx: mf_scorer.feature_vocab.decode(feat) for feat, feat_idx in mf_scorer.ngram_features.items()}
# Feature index -> feature exactly as keyed in mf_scorer.ngram_features (not decoded).
indices_to_feats_encoded = {feat_idx: feat for feat, feat_idx in mf_scorer.ngram_features.items()}


In [None]:
import matplotlib.pyplot as plt

def plot_learned_weights_compare(probs, title=None):
    """Scatter-plot learned feature weights against the oracle's weights.

    Every entry of `probs` is keyed by its decoded feature (via
    `indices_to_feats`). Each entry of `informant.scorer.ngram_features[3]`
    is pretty-printed, split into slots, and kept only if it has exactly three
    slots and also appears among the learned features. The matching
    (oracle, learned) weight pairs are plotted.

    `title`, if given, is prepended to the fixed plot title. The printed
    overlap count is reported against the combined size of
    `ngram_features[3]` and `ngram_features[2]`.
    """
    # TODO: For the trigram features that have *multiple* phonemes in the first slot, should we break that into two separate trigrams?

    # Decoded feature -> learned weight, for every index of probs.
    learned_by_feature = {
        indices_to_feats[position]: value for position, value in enumerate(probs)
    }

    # Iterate through the oracle features and keep those that were also learned.
    oracle_entries = informant.scorer.ngram_features[3]
    matched_pairs = []  # (oracle weight, learned weight)
    for template, oracle_weight in oracle_entries:
        pretty = informant.scorer.pp_feature(template)
        # Strip the first and last character of every space-separated slot.
        slots = tuple([slot[1:-1] for slot in pretty.split(' ')])
        if len(slots) == 3 and slots in learned_by_feature:
            matched_pairs.append((oracle_weight, learned_by_feature[slots]))

    oracle_weights = [pair[0] for pair in matched_pairs]
    learned_weights = [pair[1] for pair in matched_pairs]

    # Plot
    total_oracle = len(oracle_entries) + len(informant.scorer.ngram_features[2])
    print(f"# overlapping weights: {len(learned_weights)}/{total_oracle}")
    plt.scatter(oracle_weights, learned_weights, alpha=0.5)
    plt.xlabel("Oracle weights")
    plt.ylabel("Learned weights")

    heading = "" if title is None else title
    heading += "Oracle weights vs. learned weights"
    plt.title(heading)
    plt.show()
    

In [None]:
def plot_learned_weights(probs, title=None):
    """Plot the distribution of learned weights, split by oracle membership.

    Draws two histograms side by side: learned weights whose decoded feature
    is among the oracle's three-slot features, and learned weights whose
    feature is not. Every non-empty bar is annotated with its height.
    `title`, if given, is prepended to the fixed figure title.
    """
    # Oracle feature (as a 3-tuple of slots) -> oracle weight.
    oracle_features_to_weights = {}
    for template, oracle_weight in informant.scorer.ngram_features[3]:
        pretty = informant.scorer.pp_feature(template)
        slots = tuple([slot[1:-1] for slot in pretty.split(' ')])
        if len(slots) == 3:
            oracle_features_to_weights[slots] = oracle_weight

    # Partition the learned weights by whether their feature is in the oracle.
    feats_in_oracle = []
    feats_not_in_oracle = []
    for position, learned_weight in enumerate(probs):
        if indices_to_feats[position] in oracle_features_to_weights:
            bucket = feats_in_oracle
        else:
            bucket = feats_not_in_oracle
        bucket.append(learned_weight)

    # One panel per group: (axis, data, number of bins, panel title, y-limits).
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (axs[0], feats_in_oracle, 40, "Learned weights for features in oracle", [0, 15]),
        (axs[1], feats_not_in_oracle, 20, "Learned weights for features not in oracle", [0, 55000]),
    )
    for ax, weights, n_bins, panel_title, y_limits in panels:
        ax.hist(weights, bins=n_bins)
        ax.set_title(panel_title)
        ax.set_xlabel("Learned weights")
        ax.set_ylabel("Frequency")
        # Annotate the bars of the histogram with their values.
        for bar in ax.patches:
            bar_height = bar.get_height()
            if bar_height > 0:
                ax.annotate(
                    f'{bar_height:.2f}',
                    xy=(bar.get_x() + bar.get_width() / 2, bar_height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha='center',
                    va='bottom',
                    rotation=90,
                )
        ax.set_ylim(y_limits)

    heading = "" if title is None else title
    heading += "Distribution of learned weights"
    fig.suptitle(heading, fontsize=16, y=1.03)

    plt.show()

    

In [None]:
def plot_eval_costs_by_label(probs, config, eval_dataset, title=None, costs=None):
    """Plot histograms of eval costs, one panel for label 1 and one for label 0.

    Args:
        probs: learned feature probabilities; installed on a fresh scorer when
            `costs` is not supplied.
        config: run config dict with 'log_log_alpha_ratio', 'prior_prob',
            'tolerance' and 'max_updates' (the keys run() logs); only read
            when `costs` is None.
        eval_dataset: dataset with 'encoded' and 'label' columns.
        title: optional prefix for the figure title.
        costs: precomputed cost per eval item; computed from `probs` if None.
    """
    if costs is None:
        # Build a throwaway scorer from the run's own hyperparameters rather than from
        # whatever notebook globals happen to be set, then install the learned probs.
        scorer = initialize_hyp(
            config['log_log_alpha_ratio'],
            config['prior_prob'],
            config['tolerance'],
            config['max_updates'],
            informant.dataset,
            'data/hw/english_features.txt',
        )
        scorer.probs = probs
        # get costs for eval items
        costs = [scorer.cost(encod) for encod in eval_dataset['encoded'].values]

    labels = eval_dataset['label'].values

    # One panel per label value: (axis, label value, upper y-limit).
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    panels = ((axs[0], 1, 1700), (axs[1], 0, 700))
    for ax, label_value, y_max in panels:
        label_costs = [costs[i] for i, label in enumerate(labels) if label == label_value]
        ax.hist(label_costs, bins=20)
        ax.set_title(f"Eval costs for items labeled {label_value}")
        ax.set_xlabel("Eval costs")
        ax.set_ylabel("Frequency")
        ax.set_ylim([0, y_max])
        # annotate the bars of the histogram with values
        for rect in ax.patches:
            height = rect.get_height()
            if height > 0:
                ax.annotate(f'{height:.2f}', xy=(rect.get_x() + rect.get_width() / 2, height),
                            xytext=(0, 3), textcoords="offset points", ha='center', va='bottom',
                            rotation=90)

    if title is None:
        title = ""
    title += "Eval costs for items labeled 1 vs. 0"
    fig.suptitle(title, fontsize=16)
    plt.show()

In [None]:
display(df)

Unnamed: 0,run_id,metrics,tolerance,prior_prob,train_file,max_updates,num_samples,log_log_alpha_ratio,config,eval_oracle_costs_table,train_oracle_costs_table,auc,_step,_wandb,_runtime,_timestamp,learned_probs_table,probs,probs_mean,train_file_short
0,kpqlkcdq,_timestamp ...,2e-06,0.002405,WordsAndScoresFixed_newest.csv,,,5.416879,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'path': 'media/table/eval_oracle_costs_table_...,"{'size': 104670, '_type': 'table-file', 'ncols...",0.559767,3,{'runtime': 22119},22132.649222,1700268000.0,"{'_type': 'table-file', 'ncols': 1, 'nrows': 5...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.335859,eval
1,esrhwmyt,_runtime _timestamp \ 0 185.9188...,2e-06,0.002405,WordsAndScoresFixed_newest.csv,1.0,,5.416879,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,"{'nrows': 11336, 'sha256': 'b0146e055c677beef0...",0.559768,3,{'runtime': 13411},13412.489867,1700246000.0,"{'nrows': 54872, 'sha256': '33435f47aba1533e68...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.335898,eval
2,260x2m2l,_step _runtime _timestamp \ 0 ...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,,,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,{'path': 'media/table/train_oracle_costs_table...,0.598111,3,{'runtime': 27886},27900.865152,1700189000.0,"{'_type': 'table-file', 'ncols': 1, 'nrows': 5...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.230144,eval
3,tku5l5t1,_runtime _timestamp \ 0 184.0696...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,1.0,,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'size': 104670, '_type': 'table-file', 'ncols...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,0.598111,3,{'runtime': 13485},13485.906964,1700161000.0,"{'ncols': 1, 'nrows': 54872, 'sha256': '37d8e3...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.230115,eval
4,4r66aqrk,_step _runtime _timestamp \ 0 ...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,,7000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'ncols': 1, 'nrows': 11336, 'sha256': 'b0146e...","{'ncols': 1, 'nrows': 11336, 'sha256': 'b0146e...",0.606149,3,{'runtime': 15585},15591.264502,1700103000.0,{'path': 'media/table/learned_probs_table_2_13...,"[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.195201,eval
5,4oh9gzmw,eval_oracle_costs...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,1.0,7000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'nrows': 11336, 'sha256': 'b0146e055c677beef0...",{'artifact_path': 'wandb-client-artifact://lcc...,0.606149,3,{'runtime': 4467},4467.807258,1700087000.0,{'sha256': '1f8c27fdca75ec35faab77986d5808af29...,"[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.195158,eval
6,roaovkp2,auc _step _runtime _timestamp ...,2e-06,0.001385,data/BabbleRandomStringsEnglish/RandomStringsS...,1.0,7000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'artifact_path': 'wandb-client-artifact://fyw...,"{'_type': 'table-file', 'ncols': 1, 'nrows': 8...",0.499816,3,{'runtime': 2911},2912.432593,1700067000.0,"{'ncols': 1, 'nrows': 54872, 'sha256': 'c71e69...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.500644,random_strings
7,kqjx5wf4,auc _step _runtime _timestamp ...,2e-06,0.001385,data/MakingOverTrainSet/EnglishOverTrainingDat...,,1000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'_type': 'table-file', 'ncols': 1, 'nrows': 1...","{'size': 60426, '_type': 'table-file', 'ncols'...",0.692479,3,{'runtime': 714},718.528733,1700055000.0,{'sha256': '0503f33364d06d9a1da7fed1e1ebd8a611...,"[0.0885239943723408, 0.054540839935633696, 0.0...",0.004306,hw_babble
8,bo8wlszs,_timestamp ...,2e-06,0.001385,data/BabbleRandomStringsEnglish/RandomWellForm...,,1000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,{'_latest_artifact_path': 'wandb-client-artifa...,0.500194,3,{'runtime': 510},514.366387,1700055000.0,"{'size': 203550, '_type': 'table-file', 'ncols...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.209299,random_wellformed
9,105c3niz,auc _step _runtime _timestamp ...,2e-06,0.001385,data/BabbleRandomStringsEnglish/RandomStringsS...,,1000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'artifact_path': 'wandb-client-artifact://fv9...,{'artifact_path': 'wandb-client-artifact://3sb...,0.499887,3,{'runtime': 398},402.580664,1700054000.0,"{'nrows': 54872, 'sha256': 'f5d51a30a4c3a9a77d...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.230404,random_strings


In [None]:
def get_exp(df, lla, prior, max_updates, train_file_short, num_samples=None):
    """Return the single row of `df` that matches the given experiment settings.

    Args:
        df: runs table with 'log_log_alpha_ratio', 'prior_prob', 'max_updates',
            'train_file_short', 'num_samples' and 'auc' columns.
        lla, prior, max_updates, train_file_short, num_samples: values the row
            must equal. Passing None matches rows where the value is missing,
            i.e. null or the string "None" (the placeholder get_wandb_runs()
            writes for missing values); passing "None" keeps working as before.

    Returns:
        The matching row as a Series.

    Raises:
        AssertionError: if the number of matching rows is not exactly one.
    """
    def _matches(column, wanted):
        # `series == None` is elementwise False, so None needs explicit handling.
        if wanted is None:
            return df[column].isna() | (df[column] == "None")
        return df[column] == wanted

    temp = df[
        _matches('log_log_alpha_ratio', lla) &
        _matches('prior_prob', prior) &
        _matches('max_updates', max_updates) &
        _matches('train_file_short', train_file_short) &
        _matches('num_samples', num_samples)
    ]

    display(temp)

    assert len(temp) == 1, f'Expected exactly one matching run, found {len(temp)}'
    print("AUC:", temp['auc'])

    return temp.iloc[0]

def print_dict(d):
    """Print every entry of ``d`` on its own line, formatted as ``key: value``."""
    for entry in d.items():
        key, value = entry
        print("{}: {}".format(key, value))


# eval_row = df[df['run_id']=='4r66aqrk'].iloc[0]
# print_dict(eval_row['config'])

# print()
# print(eval_row['train_file_short'])
# print(eval_row['probs_mean'])

In [None]:
# List the distinct values of the 'train_file_short' column in the runs table.
df['train_file_short'].unique()

array(['eval', 'random_strings', 'hw_babble', 'random_wellformed'],
      dtype=object)

In [None]:
# Filter values that identify one run in `df`; they are passed to get_exp below.
# Commented-out assignments are alternative settings kept for quick switching.
lla = 0.522731931474557
# lla = 5.41687946870128
# prior_prob = 0.00240504883318384
prior_prob = 0.00138533389897108
max_updates=1
# 'None'
# tol = 0.000002
train_file_short = 'eval'
# train_file_short = 'random_strings'
# num_samples=10
# Note: this is the string 'None', not the Python None object.
num_samples='None'
# get_exp displays the matching rows, asserts there is exactly one, and returns it.
exp = get_exp(df, lla, prior_prob, max_updates, train_file_short, num_samples=num_samples)
# print(exp['probs'].values[0])

Unnamed: 0,run_id,metrics,tolerance,prior_prob,train_file,max_updates,num_samples,log_log_alpha_ratio,config,eval_oracle_costs_table,train_oracle_costs_table,auc,_step,_wandb,_runtime,_timestamp,learned_probs_table,probs,probs_mean,train_file_short
3,tku5l5t1,_runtime _timestamp \ 0 184.0696...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,1.0,,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'size': 104670, '_type': 'table-file', 'ncols...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,0.598111,3,{'runtime': 13485},13485.906964,1700161000.0,"{'ncols': 1, 'nrows': 54872, 'sha256': '37d8e3...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.230115,eval


AUC: 3    0.598111
Name: auc, dtype: float64


In [None]:
# For each training-set group: look up the matching run, load its stored
# probabilities into a freshly initialised scorer, check that the AUC logged
# for the run is reproduced on eval_dataset, then plot the learned weights
# and the eval costs.
# NOTE(review): the recorded output of this cell ends with
#   TypeError: cost() got an unexpected keyword argument 'length_norm'
# so the scorer returned by initialize_hyp does not accept the keyword used
# below — confirm which signature is intended.
# NOTE(review): `tol` only appears as a commented-out assignment in the nearby
# parameter cell — confirm it is defined earlier in the notebook.
length_norm = True
for train_file_short in ['eval', 'hw_babble']:
    

    # Reuses lla / prior_prob / max_updates / num_samples from the parameter cell.
    exp = get_exp(df, lla, prior_prob, max_updates, train_file_short, num_samples=num_samples)
    probs_1 = exp['probs']
    
    scorer = initialize_hyp(lla, prior_prob, tol, max_updates, informant.dataset, 'data/hw/english_features.txt')
    # Overwrite the scorer's probabilities with the ones stored for this run.
    scorer.probs = probs_1
    # Per-item costs are kept for the cost-by-label plot at the end of the loop body.
    costs = [scorer.cost(encod, length_norm=length_norm) for encod in eval_dataset['encoded'].values]
    auc = get_auc(scorer, eval_dataset, length_norm=length_norm)
    # Exact float equality: any rounding of the logged AUC would make this fail.
    assert auc == exp['auc'], f'{auc} != {exp["auc"]}'
    title = f'train file: {train_file_short} (auc = {round(auc, 3)})\n'

    print(probs_1)
    print(probs_1.mean())

    plot_learned_weights_compare(probs_1, title=title)
    plot_learned_weights(probs_1, title=title)
    plot_eval_costs_by_label(probs_1, exp['config'], eval_dataset, title=title, costs=costs)

Unnamed: 0,run_id,metrics,tolerance,prior_prob,train_file,max_updates,num_samples,log_log_alpha_ratio,config,eval_oracle_costs_table,train_oracle_costs_table,auc,_step,_wandb,_runtime,_timestamp,learned_probs_table,probs,probs_mean,train_file_short
3,tku5l5t1,_runtime _timestamp \ 0 184.0696...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,1.0,,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'size': 104670, '_type': 'table-file', 'ncols...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,0.598111,3,{'runtime': 13485},13485.906964,1700161000.0,"{'ncols': 1, 'nrows': 54872, 'sha256': '37d8e3...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.230115,eval


AUC: 3    0.598111
Name: auc, dtype: float64
Initializing learner...
Reading phoneme features from: data/hw/english_features.txt
# features:  54872


TypeError: cost() got an unexpected keyword argument 'length_norm'

In [None]:
# Display the full runs table.
df

Unnamed: 0,run_id,metrics,tolerance,prior_prob,train_file,max_updates,num_samples,log_log_alpha_ratio,config,eval_oracle_costs_table,train_oracle_costs_table,auc,_step,_wandb,_runtime,_timestamp,learned_probs_table,probs,probs_mean,train_file_short
0,kpqlkcdq,_timestamp ...,2e-06,0.002405,WordsAndScoresFixed_newest.csv,,,5.416879,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'path': 'media/table/eval_oracle_costs_table_...,"{'size': 104670, '_type': 'table-file', 'ncols...",0.559767,3,{'runtime': 22119},22132.649222,1700268000.0,"{'_type': 'table-file', 'ncols': 1, 'nrows': 5...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.335859,eval
1,esrhwmyt,_runtime _timestamp \ 0 185.9188...,2e-06,0.002405,WordsAndScoresFixed_newest.csv,1.0,,5.416879,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,"{'nrows': 11336, 'sha256': 'b0146e055c677beef0...",0.559768,3,{'runtime': 13411},13412.489867,1700246000.0,"{'nrows': 54872, 'sha256': '33435f47aba1533e68...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.335898,eval
2,260x2m2l,_step _runtime _timestamp \ 0 ...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,,,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,{'path': 'media/table/train_oracle_costs_table...,0.598111,3,{'runtime': 27886},27900.865152,1700189000.0,"{'_type': 'table-file', 'ncols': 1, 'nrows': 5...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.230144,eval
3,tku5l5t1,_runtime _timestamp \ 0 184.0696...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,1.0,,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'size': 104670, '_type': 'table-file', 'ncols...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,0.598111,3,{'runtime': 13485},13485.906964,1700161000.0,"{'ncols': 1, 'nrows': 54872, 'sha256': '37d8e3...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.230115,eval
4,4r66aqrk,_step _runtime _timestamp \ 0 ...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,,7000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'ncols': 1, 'nrows': 11336, 'sha256': 'b0146e...","{'ncols': 1, 'nrows': 11336, 'sha256': 'b0146e...",0.606149,3,{'runtime': 15585},15591.264502,1700103000.0,{'path': 'media/table/learned_probs_table_2_13...,"[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.195201,eval
5,4oh9gzmw,eval_oracle_costs...,2e-06,0.001385,WordsAndScoresFixed_newest.csv,1.0,7000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'nrows': 11336, 'sha256': 'b0146e055c677beef0...",{'artifact_path': 'wandb-client-artifact://lcc...,0.606149,3,{'runtime': 4467},4467.807258,1700087000.0,{'sha256': '1f8c27fdca75ec35faab77986d5808af29...,"[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.195158,eval
6,roaovkp2,auc _step _runtime _timestamp ...,2e-06,0.001385,data/BabbleRandomStringsEnglish/RandomStringsS...,1.0,7000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'artifact_path': 'wandb-client-artifact://fyw...,"{'_type': 'table-file', 'ncols': 1, 'nrows': 8...",0.499816,3,{'runtime': 2911},2912.432593,1700067000.0,"{'ncols': 1, 'nrows': 54872, 'sha256': 'c71e69...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.500644,random_strings
7,kqjx5wf4,auc _step _runtime _timestamp ...,2e-06,0.001385,data/MakingOverTrainSet/EnglishOverTrainingDat...,,1000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...","{'_type': 'table-file', 'ncols': 1, 'nrows': 1...","{'size': 60426, '_type': 'table-file', 'ncols'...",0.692479,3,{'runtime': 714},718.528733,1700055000.0,{'sha256': '0503f33364d06d9a1da7fed1e1ebd8a611...,"[0.0885239943723408, 0.054540839935633696, 0.0...",0.004306,hw_babble
8,bo8wlszs,_timestamp ...,2e-06,0.001385,data/BabbleRandomStringsEnglish/RandomWellForm...,,1000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'sha256': 'b0146e055c677beef03e6b4ec0b4fb9b79...,{'_latest_artifact_path': 'wandb-client-artifa...,0.500194,3,{'runtime': 510},514.366387,1700055000.0,"{'size': 203550, '_type': 'table-file', 'ncols...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.209299,random_wellformed
9,105c3niz,auc _step _runtime _timestamp ...,2e-06,0.001385,data/BabbleRandomStringsEnglish/RandomStringsS...,,1000.0,0.522732,"{'tolerance': 1.953125e-06, 'prior_prob': 0.00...",{'artifact_path': 'wandb-client-artifact://fv9...,{'artifact_path': 'wandb-client-artifact://3sb...,0.499887,3,{'runtime': 398},402.580664,1700054000.0,"{'nrows': 54872, 'sha256': 'f5d51a30a4c3a9a77d...","[0.99999, 0.99999, 0.99999, 0.99999, 0.99999, ...",0.230404,random_strings


In [251]:
# Print every position at which the two probability vectors differ.
# NOTE(review): `probs_2` is not assigned in the cells visible here — confirm
# it is defined elsewhere before relying on this cell after a kernel restart.
for p1, p2 in zip(probs_1, probs_2):
    if p1 != p2:
        print(p1, p2)

In [210]:
# Fetch the first row whose run_id is 'kpqlkcdq' and print its config and
# training-set short name.
eval_row = df[df['run_id']=='kpqlkcdq'].iloc[0]
print(eval_row['config'])
print(eval_row['train_file_short'])

{'tolerance': 1.953125e-06, 'prior_prob': 0.00240504883318384, 'train_file': 'WordsAndScoresFixed_newest.csv', 'max_updates': None, 'num_samples': None, 'log_log_alpha_ratio': 5.41687946870128}
eval


In [None]:
# Print each learned probability of the selected run next to the feature
# stored at the same index in the two lookup tables.
# NOTE(review): `indices_to_feats` and `indices_to_feats_encoded` are not
# defined in the cells visible here — confirm where they are built.
for feat_idx, weight in enumerate(eval_row['probs']):
    print(feat_idx, round(weight, 4), indices_to_feats[feat_idx], indices_to_feats_encoded[feat_idx])
    

In [47]:
# Walk the (feature template, weight) pairs stored under key 2 of the
# informant scorer's ngram_features and print the raw template, its weight,
# and the readable form produced by pp_feature.
# presumably key 2 holds the two-position (bigram) templates — TODO confirm
for (feat_template, weight) in informant.scorer.ngram_features[2]:
    print(feat_template)
    print(weight)
    print(informant.scorer.pp_feature(feat_template))
    print()
    

((6,), (10,))
1.967
[+continuant] [+voice]

((3, 31), (3, 1))
3.324
[-approximant, -tense] [-approximant, -consonantal]

((3, 31), (33,))
1.802
[-approximant, -tense] [-low]

((11,), (10,))
2.787
[-voice] [+voice]

((3, 28), (3, 14, 1))
1.053
[-approximant, +back] [-approximant, +labial, -consonantal]

((10, 16), (20,))
2.824
[+voice, +coronal] [+strident]

((6,), (6, 20))
3.679
[+continuant] [+continuant, +strident]

((32,), (26, 28))
0.449
[+low] [+high, +back]

((7, 16), (7,))
3.275
[-continuant, +coronal] [-continuant]

((6, 10), (5, 16))
0.054
[+continuant, +voice] [-sonorant, +coronal]

((6, 10), (11,))
3.031
[+continuant, +voice] [-voice]

((3, 19), (0,))
4.383
[-approximant, -anterior] [+consonantal]

((36,), (3, 26, 28))
2.826
[+word_boundary] [-approximant, +high, +back]

((3, 1), (3, 26, 28))
1.693
[-approximant, -consonantal] [-approximant, +high, +back]

((32,), (3, 14, 26))
0.81
[+low] [-approximant, +labial, +high]

((3, 27, 29), (30, 33))
0.972
[-approximant, -high, -ba

In [None]:
# Render one hand-written three-position feature template in readable form.
informant.scorer.pp_feature(((36,), (7,), (3, 0)))

In [None]:
# For every key of the informant scorer's ngram_features, print the key, the
# number of entries stored under it, and the entries themselves.
for k, v in informant.scorer.ngram_features.items():
    print(k)
    print(len(v))
    print(v)

In [None]:
# Print the informant scorer's whole ngram_features mapping in one go.
print(informant.scorer.ngram_features)

In [None]:
# Show the probability vector stored for the selected run.
eval_row['probs']

In [None]:
# Print each value of mf_scorer.ngram_features next to its key decoded through
# the scorer's feature vocabulary.
# presumably keys are encoded features and values are feature indices — TODO confirm
for feat, feat_idx in mf_scorer.ngram_features.items():
    decoded = (mf_scorer.feature_vocab.decode(feat))
    print(feat_idx, decoded)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt



def plot_auc(df, max_updates=1):
    """Bar-plot AUC against ``num_samples``, with one hue per training set.

    Rows whose 'max_updates' equals ``max_updates`` are dropped before
    plotting. Training files are translated to display names through the
    notebook-level ``file_map`` mapping, which must be defined before calling.

    Args:
        df: runs table with 'max_updates', 'train_file', 'num_samples' and
            'auc' columns.
        max_updates: the 'max_updates' value to exclude from the plot.

    Raises:
        KeyError: if a 'train_file' value has no entry in ``file_map``.
    """
    # Work on a copy so that adding a column below never writes into a view of
    # the caller's DataFrame (avoids pandas' SettingWithCopyWarning).
    temp = df[df['max_updates'] != max_updates].copy()

    # NOTE(review): the title is fixed and describes the rows kept under the
    # default filter — confirm it is still right for other max_updates values.
    title = 'max_updates=None'

    temp['train_name'] = temp['train_file'].map(lambda train_file: file_map[train_file])

    # Choose the grouping key (either "num_samples" or "train_file")
    grouping_key = "num_samples"
    # The hue is whichever of the two keys is not on the x-axis.
    hue_key = "train_name" if grouping_key == "num_samples" else "num_samples"

    # One colour per hue level: the palette colours the hue variable, so it
    # must be sized by the hue key rather than by the x-axis grouping key.
    palette = sns.color_palette("husl", n_colors=len(temp[hue_key].unique()))

    # Create a bar plot
    ax = sns.barplot(
        x=grouping_key,
        y="auc",
        hue=hue_key,
        data=temp,
        palette=palette,
        order=[10.0, 1000.0, 5000.0, 7000.0, "None"],  # Specify the order of x-axis values
    )
    # Anchor the legend to the right of the axes so it does not cover the bars.
    ax.legend(bbox_to_anchor=(1.05, 0), loc='lower center', borderaxespad=0.)

    plt.title(title)
    # Show the plot
    plt.show()

In [None]:
# Display the full runs table.
df

In [None]:
# Plot AUC per num_samples using the default filter (drops rows with max_updates == 1).
plot_auc(df)

In [None]:
# Get all runs from the project
# Query the "lm-informants/1114_big_batch" project and keep the first run
# returned for the inspection cells that follow.
api = wandb.Api()
runs = api.runs("lm-informants/1114_big_batch")
run = runs[0]

In [None]:
# Call the notebook's load_probs helper on the selected run.
# NOTE(review): load_probs is not defined in the cells visible here — confirm its return value.
load_probs(run)

In [None]:
# Path of the saved probabilities file for one specific hyperparameter
# setting, restored from run 'kpqlkcdq' of the project.
# presumably wandb.restore downloads the file locally and returns a file handle — TODO confirm
file_name = 'big_batch/lla=5.41687946870128_prior=0.00240504883318384_max-updates=None_tol=1.953125e-06_num-samples=None/probs.npy'
f = wandb.restore(file_name, run_path="lm-informants/1114_big_batch/kpqlkcdq")



In [None]:
# Echo the hand-written probabilities path for comparison with get_probs_file.
file_name

In [None]:
# Call the notebook's get_probs_file helper on the selected run's config.
# NOTE(review): get_probs_file is not defined in the cells visible here.
get_probs_file(run.config)

In [None]:
# Show the `name` attribute of the object returned by wandb.restore.
f.name

In [None]:
# Load the restored .npy file with NumPy and display its contents.
np.load(f.name)

In [None]:
# Load the over-training CSV with the current informant and scorer, then display it.
train_dataset = load_train_dataset('data/MakingOverTrainSet/EnglishOverTrainingData.csv', informant, mf_scorer)
display(train_dataset)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import math

def plot_histogram_scores(df, title=None):
    """Draw a histogram of the values in ``df['cost']`` using 0.5-wide bins.

    Args:
        df: DataFrame with a 'cost' column.
        title: optional figure title; omitted when ``None``.
    """
    bin_size = 0.5
    costs = df['cost']

    # Snap the plotted range outward to whole numbers so that the bin edges
    # fall on multiples of the bin size.
    lower_edge = math.floor(min(costs))
    upper_edge = math.ceil(max(costs))
    bins = np.arange(lower_edge, upper_edge + bin_size, bin_size)

    # Widen the figure with the number of bin edges so the tick labels stay legible.
    plt.figure(figsize=(len(bins) / 2, 2))
    plt.hist(costs, bins=bins, color='blue', edgecolor='black')

    # Axis labels, one tick per bin edge, and the optional title.
    plt.xlabel('Score')
    plt.xticks(bins)
    plt.ylabel('Frequency')
    if title is not None:
        plt.title(title)

    plt.show()
    
def plot_histogram_labels(df, title=None):
    """Draw a bar chart of how many rows carry each value of ``df['label']``.

    Args:
        df: DataFrame with a 'label' column.
        title: optional figure title; omitted when ``None``.
    """
    # Tally the rows per label and draw the tallies as bars in one step.
    df['label'].value_counts().plot(kind='bar', edgecolor='black')

    # Axis labels and the optional title.
    plt.xlabel('Label')
    plt.ylabel('Count')
    if title is not None:
        plt.title(title)

    plt.show()

In [None]:
# Training CSVs to compare; train_datasets holds the loaded dataset for each
# path, in the same order as train_files.
train_files = [
    'data/MakingOverTrainSet/EnglishOverTrainingData.csv',
    'data/BabbleRandomStringsEnglish/RandomStringsSubsampledBalanced.csv',
    'data/BabbleRandomStringsEnglish/RandomWellFormedSyllablesSubsampledBalanced.csv',    
]

# Every file is loaded with the same informant and scorer.
train_datasets = [load_train_dataset(f, informant, mf_scorer) for f in train_files]


In [None]:
# Plot the cost histogram and the label counts for every training set.
# The datasets were already loaded into `train_datasets`, in the same order as
# `train_files`, so nothing is re-read from disk here.
for (train_file, train_dataset) in zip(train_files, train_datasets):
    plot_histogram_scores(train_dataset, title=f'Histogram of Scores:\n{train_file}')
    # The label chart needs its own title; it used to repeat "Histogram of Scores".
    plot_histogram_labels(train_dataset, title=f'Histogram of Labels:\n{train_file}')
    

In [254]:
# Rebuild the evaluation dataset with the current informant and scorer.
# NOTE(review): the same call already runs near the top of the notebook, so
# this cell rebinds eval_dataset for every cell below it.
eval_dataset = load_eval_dataset(informant, mf_scorer)

100%|██████████| 12390/12390 [00:00<00:00, 1377459.82it/s]
100%|██████████| 12390/12390 [00:00<00:00, 541852.28it/s]
100%|██████████| 12390/12390 [00:16<00:00, 733.66it/s]
100%|██████████| 12390/12390 [00:32<00:00, 382.22it/s]
100%|██████████| 12390/12390 [00:33<00:00, 372.00it/s]


In [None]:
# eval_dataset = load_eval_dataset(informant, mf_scorer)
plot_histogram_scores(eval_dataset, title=f'Histogram of Scores:\nEVAL')
plot_histogram_labels(eval_dataset, title=f'Histogram of Labels:\nEVAL')

In [None]:
# Count how many evaluation items carry each label.
eval_dataset['label'].value_counts()

In [None]:
# Print the number of rows in each loaded training dataset.
for td in train_datasets:
#     print(td['item'].value_counts().max())
    print(len(td))