In [1]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
from typing import *

from typet5.utils import proj_root, get_data_dir

os.chdir(proj_root())

datadir = get_data_dir()

In [2]:
# experiment configurations

import torch

from typet5.data import (
    TokenizedSrcSet,
    get_dataset_name,
    load_tokenized_srcsets,
    TypeCheckSettings,
)
from copy import copy
from typet5.train import TrainingConfig, TypeCheckArgs
from typet5.type_check import MypyChecker

config = TrainingConfig(quicktest=False, all_labels=True)
train_R1: bool = True
gpu_id = 1
TypeCheckSettings.temp_path = f"GPU-{gpu_id}"

project_name = "test-SPOT" if config.quicktest else "SPOT"

max_tokens_per_file = config.ctx_size

datasets_name = get_dataset_name(
    drop_comments=config.drop_comments,
    all_labels=config.all_labels,
)

tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)

r0_model_name = "R0-model--" + config._replace(quicktest=False).as_name()

tk_dataset = load_tokenized_srcsets(
    datadir,
    datasets_name,
    data_reduction=config.data_reduction,
    quicktest=config.quicktest,
)


  warn(f"Failed to load image Python extension: {e}")


Loading datasets:  tk_dataset-all_labels-drop_comments


In [3]:
# load trained model
from typet5.utils import pickle_load, pickle_dump
from typet5.model import ModelWrapper


r0_wrapper = ModelWrapper.from_pretrained(
    datadir / f"checkpoints/lit-saved/{r0_model_name}"
)
# if train_R1:
#     r0_extra = pickle_load(datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl")
#     r1_tk_dataset: dict[str, TokenizedSrcSet] = r0_extra["R1-tk_dataset"]
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
r0_wrapper.to(device)
print(r0_wrapper.args)


DecodingArgs(ctx_args=CtxArgs(ctx_size=4096, left_margin=2048, right_margin=1024), sampling_max_tokens=32768, max_workers=20)


In [4]:
# load the critics

from typet5.critic import CriticModel, get_critic_name

critics = dict[bool, CriticModel]()
for new_data in [True, False]:
    critic_name = get_critic_name(
        no_feedback=False, new_data=new_data, config=config._replace(quicktest=False)
    )
    critic = CriticModel.load(datadir / f"checkpoints/lit-saved/{critic_name}")
    critic.to(device)
    critics[new_data] = critic
print("Critics loaded.")


Critics loaded.


In [5]:
# set up the inference

from typet5.model import DatasetPredResult
from typet5.utils import pretty_print_dict, run_long_task, PickleCache
from typet5.model import CtxArgs, DecodingArgs, ModelSPOT

testset = tk_dataset["test"]
# if not config.quicktest:
#     testset = testset[1:-1:10]

# used for inference
n_samples = 16
dec_ctx_args = config.dec_ctx_args()
dec_ctx_args.max_labels = 1  # one type per chunk
greedy_args = DecodingArgs(
    sampling_max_tokens=8 * max_tokens_per_file,
    ctx_args=dec_ctx_args,
    max_workers=28,
    do_sample=False,
)

sample_args = DecodingArgs(
    sampling_max_tokens=8 * max_tokens_per_file,
    ctx_args=dec_ctx_args,
    max_workers=28,
    do_sample=True,
    top_p=0.9,
)

bs_args = DecodingArgs(
    sampling_max_tokens=max_tokens_per_file,
    ctx_args=dec_ctx_args,
    max_workers=28,
    do_sample=False,
    num_beams=n_samples,
)

bs_incr_args = DecodingArgs(
    ctx_args=dec_ctx_args,
    sampling_max_tokens=max_tokens_per_file,
    max_workers=28,
    tokens_per_type=10,
    do_sample=False,
    num_beams=n_samples,
)


eval_cache = PickleCache(proj_root() / "caches" / "inference_spot" / r0_model_name)
# eval_cache.clear()


In [6]:
# compute results
from typet5.decode import (
    sample_candidates,
    select_candidates_by_type_errors,
    select_candidates_using_oracle,
    select_candidates_using_critic,
    select_first_candidates,
    incr_inference_with_feedback,
    SelectByOracle,
    SelectByCounting,
    SelectByCritic,
)

def score_transform(x: float):
    if x <= 0.2:
        return -1.0
    if x >= 0.8:
        return 1.0
    return 0.0

results = dict[str, DatasetPredResult]()
incr_results = dict[str, Any]()

with run_long_task("Computing results"):
    # r0_wrapper.args = bs_args
    # results["BS"] = evaluate_model(r0_wrapper, None, testset, eval_cache=eval_cache, tc_args=tc_args)[0][1]

    r0_wrapper.args = bs_incr_args

    incr_results["IncrCount"] = eval_cache.cached(
        "Result-IncrCount",
        lambda: incr_inference_with_feedback(
            r0_wrapper,
            testset,
            beam_width=8,
            selector=SelectByCounting(),
            log_to=proj_root() / "caches/IncrCount-Examples",
        ),
    )

    # incr_results["IncrCritic"] = eval_cache.cached(
    #     "Result-IncrCritic",
    #     lambda: incr_inference_with_feedback(
    #         r0_wrapper,
    #         testset,
    #         beam_width=8,
    #         selector=SelectByCritic(critics[False], score_transform),
    #         log_to=proj_root() / "caches/IncrCritic-Examples",
    #     ),
    # )

    # incr_results["IncrCritic-new"] = eval_cache.cached(
    #     "Result-IncrCritic-new",
    #     lambda: incr_inference_with_feedback(
    #         r0_wrapper,
    #         testset,
    #         beam_width=8,
    #         selector=SelectByCritic(critics[True], score_transform),
    #         log_to=proj_root() / "caches/IncrCritic-new-Examples",
    #     ),
    # )

    incr_results["IncrOracle"] = eval_cache.cached(
        "Result-IncrOracle",
        lambda: incr_inference_with_feedback(
            r0_wrapper,
            testset,
            beam_width=8,
            selector=SelectByOracle(),
            log_to=proj_root() / "caches/IncrOracle-Examples",
        ),
    )

    # r0_wrapper.args = bs_args
    # test_chunks, pred_candidates = eval_cache.cached(
    #     "sample_candidates",
    #     lambda: sample_candidates(r0_wrapper, testset, n_samples=n_samples),
    # )

    # results["BS"] = select_first_candidates(test_chunks, pred_candidates)

    # results["Counting"] = eval_cache.cached(
    #     "Result-Counting",
    #     lambda: select_candidates_by_type_errors(testset, test_chunks, pred_candidates),
    # )

    # critic = critics[False]
    # r_name = "Critic"
    # results[r_name] = eval_cache.cached(
    #     f"Result-{r_name}",
    #     lambda: select_candidates_using_critic(
    #         critic,
    #         False,
    #         testset,
    #         test_chunks,
    #         pred_candidates,
    #         dec_args=greedy_args,
    #         # score_transform=score_transform,
    #     ),
    # )

    # results["Oracle"] = eval_cache.cached(
    #     "Result-Oracle",
    #     lambda: select_candidates_using_oracle(test_chunks, pred_candidates),
    # )


Starting task: Computing results
Pushover: (Finished: 'Computing results'.) Time taken: 4.0s


In [7]:
from typet5.visualization import display_persist, visualize_dicts
from typet5.data import src_preds_to_accuracies
from typet5.visualization import display_persist, dict_widget


accs_list = [x.accuracies for x in results.values()]
titles = list(results.keys())

for n, r in incr_results.items():
    accs = src_preds_to_accuracies(r[1], r[0])
    accs_list.append(accs)
    titles.append(n)

display_persist(visualize_dicts(accs_list, titles))


In [8]:
from typet5.utils import pd, display

grouped_res = results["BS + critic-False"].group_by_repo()
grouped_full_acc = {k: v.accuracies["full_acc"] for k, v in grouped_res.items()}
repos = list(grouped_full_acc.keys())
repos.sort(key=lambda x: grouped_full_acc[x].n_total, reverse=True)

grouped_acc_bs = {
    k: v.accuracies["full_acc"] for k, v in results["BS"].group_by_repo().items()
}
grouped_oracle_bs = {
    k: v.accuracies["full_acc"]
    for k, v in results["BS + oracle"].group_by_repo().items()
}

table = pd.DataFrame(
    {
        "Repo": [r.name for r in repos],
        "BS": [str(grouped_acc_bs[r]) for r in repos],
        "Critic": [str(grouped_full_acc[r]) for r in repos],
        "Oracle": [str(grouped_oracle_bs[r]) for r in repos],
    }
)
display(table)


KeyError: 'BS + critic-False'

In [None]:
from typet5.utils import not_none
from typet5.visualization import visualize_preds_on_code

critic_eval = results[critic_result_name(False)]
preds_extra = {
    "critic_preds": [
        x.candidate_label_scores[x.best_candidate] for x in critic_eval.extra_info
    ]
}
visualize_preds_on_code(critic_eval.chunks, critic_eval.predictions, preds_extra)


VBox(children=(IntSlider(value=0, continuous_update=False, max=178), VBox(children=(VBox(children=(Box(childre…

In [None]:
from typet5.decode import collect_type_errors_from_predictions
from typet5.model import DatasetPredResult
from typet5.type_check import PythonType, MypyFeedback
from typet5.data import TokenizedSrcSet


def collect_base_errors(dataset: TokenizedSrcSet):
    "Collect the type errors triggered by replacing all labels with `Any`."
    chunks = dataset.to_chunks(
        r0_wrapper.args.ctx_args, tqdm_args={"disable": True}
    )
    dummy_preds = [
        [PythonType(("Any",)) for _ in info.types] for info in chunks.chunks_info
    ]
    pred_r = DatasetPredResult(chunks, dummy_preds)
    return collect_type_errors_from_predictions(dataset, pred_r, max_workers=30)


def collect_gold_errors(dataset: TokenizedSrcSet):
    "Collect the type errors triggered by ground-truth labels."
    chunks = dataset.to_chunks(
        r0_wrapper.args.ctx_args, tqdm_args={"disable": True}
    )
    label_preds = [info.types for info in chunks.chunks_info]
    pred_r = DatasetPredResult(chunks, label_preds)
    return collect_type_errors_from_predictions(dataset, pred_r, max_workers=30)


num_labels = sum(len(s.types) for s in testset.all_srcs)
print("Total number of labels: ", num_labels)
type_errors = dict[str, list[tuple[Path, MypyFeedback]]]()
type_errors["default"] = collect_base_errors(testset)
type_errors["gold"] = collect_gold_errors(testset)
for k, v in results.items():
    type_errors[k] = collect_type_errors_from_predictions(testset, v, max_workers=30)

from typet5.visualization import dict_widget, display_persist

display_persist(dict_widget({k: len(v) for k, v in type_errors.items()}))


In [None]:
from typet5.visualization import seq_flatten, visualize_counts
from typet5.utils import Counter
from typet5.type_check import count_type_frequency


def show_type_distr(recursive: bool, top_k: int):
    counts = dict[str, Counter]()
    for name in ["greedy", "BS + feedback"]:
        types = seq_flatten(results[name].predictions)
        counts[name] = count_type_frequency(types, recursive=recursive)

    display(visualize_counts(counts, x_name="Predicted Type", top_k=top_k))


show_type_distr(recursive=True, top_k=15)


In [None]:
from typet5.visualization import visualize_counts, visualize_sequence_tabs, display
from typet5.utils import Counter

default_counts = Counter(e.error_code for _, e in type_errors["default"])

error_counts = dict[str, Counter]()
for name in ["gold"]:  # ["greedy", "BS + feedback"]:
    c = Counter(e.error_code for _, e in type_errors[name])
    for e, v in default_counts.items():
        c[e] -= v
    error_counts[name] = c
display(visualize_counts(error_counts, "Error"))


In [None]:
from typet5.visualization import visualize_conf_matrix

visualize_conf_matrix(results)
