# BT4Py HiTypeT5 Top-1

In [7]:
%load_ext autoreload
%autoreload 2

%pwd
%cd /home/benji/Documents/Uni/heidelberg/05/masterarbeit/impls/scripts/experiments

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/benji/Documents/Uni/heidelberg/05/masterarbeit/impls/scripts/experiments


In [8]:
import polars as pl

pl.Config.set_fmt_str_lengths(300)
pl.Config.set_tbl_rows(n=50)

polars.config.Config

In [9]:
import pathlib

from scripts.common.schemas import TypeCollectionCategory
from scripts.infer.structure import DatasetFolderStructure

tool = "HiTypeT5PyN1"
dataset = DatasetFolderStructure(pathlib.Path(
    "/home/benji/Documents/Uni/heidelberg/05/masterarbeit/datasets/better-types-4-py-dataset"
))


In [10]:
import logging
from importlib import reload

logging.shutdown()
reload(logging)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
for handler in logger.handlers:
    logger.removeHandler(handler)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)

ch.setFormatter(logging.Formatter(f"[{tool} @ %(levelname)s]: %(message)s"))
logger.addHandler(ch)

logger.info("Hello World!")


[HiTypeT5PyN1 @ INFO]: Hello World!


# Ensure Datapoints line up

In [18]:
import pathlib, shutil, tqdm

from typet5.static_analysis import PythonProject
from typet5.experiments.utils import remove_newer_syntax_for_repo
from typet5.experiments.type4py import Type4PySupportedSyntax

repos_dir = dataset.dataset_root
hity4py_rewritten_repos = pathlib.Path() / "hity4py-repos"
if not hity4py_rewritten_repos.is_dir():
    logger.info("Creating rewritten repos")
    shutil.copytree(repos_dir / "repos" / "test", hity4py_rewritten_repos)
    remove_newer_syntax_for_repo(hity4py_rewritten_repos, Type4PySupportedSyntax)

logger.info(f"Rewritten repos are at {hity4py_rewritten_repos.resolve()}")
name_to_location = {r.name: r for r in (repos_dir / "repos" / "test").iterdir() if r.is_dir()}
print(len(name_to_location))

name_to_parsed = {
    name: PythonProject.parse_from_root(root, discard_bad_files=True) 
    for name, root in tqdm.tqdm(name_to_location.items(), "Parsing labels")
}
# print(name_to_parsed["ActivityWatch__aw-research"])

[HiTypeT5PyN1 @ INFO]: Rewritten repos are at /home/benji/Documents/Uni/heidelberg/05/masterarbeit/impls/scripts/experiments/hity4py-repos


50



Parsing labels:   0%|                                                                                                                                                                                                  | 0/50 [00:00<?, ?it/s][A
Parsing labels:   2%|███▋                                                                                                                                                                                      | 1/50 [00:08<06:36,  8.09s/it][A
Parsing labels:   4%|███████▍                                                                                                                                                                                  | 2/50 [00:09<03:11,  3.99s/it][A
Parsing labels:   6%|███████████▏                                                                                                                                                                              | 3/50 [00:10<02:01,  2.59s/it][A
Parsing labels:   8%|██████████

In [19]:
# Constants
# Common Type Names
from typet5.model import ModelWrapper
model = ModelWrapper.load_from_hub("MrVPlusOne/TypeT5-v7")
common_names = model.common_type_names
del model

from scripts.common.schemas import TypeCollectionCategory

Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

In [28]:
import pathlib, collections

# Because our analysis reviews more datapoints than these models actually regard, reuse TypeT5 metrics instead
def typet5_metrics_4_hitt5(task: TypeCollectionCategory | str) -> None:
    def cleanup_predictions(predictions: dict) -> dict:
        # {'global@global': [{'category': 'local', 'name': '__all__', 'type': ['tuple[typing.Text]']}]}
        rewritten_predictions = collections.defaultdict(list)
        for scope, ps in predictions.items():
            for p in ps:
                rewritten_p = p.copy()
                rewritten_p["type"] = [ty.replace("typing.Text", "str") for ty in rewritten_p["type"]]
                
                # if rewritten_p["type"] or rewritten_p["category"] != "local":
                rewritten_predictions[scope].append(rewritten_p)
        
        return rewritten_predictions
    
    from scripts.common.output import InferenceArtifactIO

    test_set = dataset.test_set()

    proj2datasets = [
        (
            project,
            InferenceArtifactIO(
                artifact_root=pathlib.Path(),
                dataset=dataset,
                repository=project,
                tool_name=tool,
                task=task,
            ),
        )
        for project in test_set
    ]
    existing = dict(
        (project, artifact)
        for project, artifact in proj2datasets
        if artifact.full_location().exists()
    )
    
    from importlib import reload

    from libcst import helpers as h
    import tqdm


    from typet5.experiments import hityper, type4py
    from typet5.static_analysis import PythonProject, SignatureMap, AccuracyMetric, SignatureErrorAnalysis
    from typet5.experiments.typet5 import accs_as_table_row
    from typet5.visualization import pretty_print_dict
    import pprint
    reload(hityper)

    assignments = dict[str, SignatureMap]()
    projects = dict[str, PythonProject]()

    for (project, artifact) in tqdm.tqdm(existing.items(), desc=f"Loading predictions from {task}"):
        type4py_predictions, hity4py_predictions = artifact.read()
        
        # rewritten_project_location = name_to_location[project.name]
        rewritten_project_parsed = name_to_parsed[project.name]
        
        # pprint.pprint(hity4py_predictions)
        # print()
        
        sigmap = SignatureMap()
        for fname, mres in hity4py_predictions.items():
            #print(f"{fname=}, {mres=}")
            
            mres_rewritten = cleanup_predictions(mres)
            #print(fname, mres_rewritten)

            # Removing first three path tokens from fname because they represent the temporary directory / tmp and tmpasdas
            root, tmp, tmpfolder, *project_parts = pathlib.Path(fname).parts       
            mname = PythonProject.rel_path_to_module_name(pathlib.Path(*project_parts))
            
            parser = hityper.HiTyperResponseParser(module=mname)
            sigmap.update(parser.parse(mres_rewritten))
        
            #print(f"Adding sigmap for {mname}")
        assignments[project.name] = sigmap
        
        #print(f"Registering labels for {rewritten_project_location.name}")
        projects[project.name] = rewritten_project_parsed
    
    label_signatures: dict[str, SignatureMap] = {
        project_name: {e.path: e.get_signature() for e in project.all_elems()}
        for project_name, project in projects.items()
    }
    
    pred_signatures = assignments
    
    print(f"{len(label_signatures)=}, {len(pred_signatures)=}")
    
    #for (pname, mname), o in zip(module_srcs, assignments):        
        # pred_signatures[pname].update(o)

    # print("LABEL SIGS:", "ActivityWatch__aw-research", label_signatures["ActivityWatch__aw-research"], "\n")
    # print("PRED SIGS:", "ActivityWatch__aw-research", pred_signatures["ActivityWatch__aw-research"], "\n")

    eval_result = type4py.Type4PyEvalResult(
        pred_maps=pred_signatures,
        label_maps=label_signatures,
    )
    
    metrics = AccuracyMetric.default_metrics(common_type_names=common_names)
    # acc_metric = AccuracyMetric(common_type_names=ubiq_names)

    n_annots = sum([e.get_signature().n_annots() for _, p in projects.items() for e in p.all_elems()])
    n_labels = sum([e.n_annotated() for lm in eval_result.label_maps.values() for e in lm.values()])
    
    logger.info(f"n_annots: {n_annots}, n_labels: {n_labels}")
    logger.info(f"Ratio: {n_labels / n_annots}")
    
    accs = {
        m.name: SignatureErrorAnalysis(
            eval_result.pred_maps,
            eval_result.label_maps,
            m,
            error_on_mismatched_signature=False,
        ).accuracies
        for m in metrics
    }
    accs_as_table_row(accs)
    pretty_print_dict(accs)

In [29]:
typet5_metrics_4_hitt5(task="all")


Loading predictions from all:   0%|                                                                                                                                                                                    | 0/50 [00:00<?, ?it/s][A
Loading predictions from all:   8%|█████████████▊                                                                                                                                                              | 4/50 [00:00<00:01, 25.48it/s][A
Loading predictions from all:  14%|████████████████████████                                                                                                                                                    | 7/50 [00:00<00:04,  9.25it/s][A
Loading predictions from all:  24%|█████████████████████████████████████████                                                                                                                                  | 12/50 [00:00<00:02, 16.69it/s][A
Loading predictions from all:  

len(label_signatures)=50, len(pred_signatures)=50


[HiTypeT5PyN1 @ INFO]: n_annots: 30070, n_labels: 16520
[HiTypeT5PyN1 @ INFO]: Ratio: 0.5493847688726305


[((ty'typing.Callable[str, CompileFn]', ty'CompileFn'), 309),
 ((ty'str', ty'int'), 125),
 ((ty'typing.Any', ty'Any'), 97),
 ((ty'PyASTCtx', ty'GeneratorContext'), 79),
 ((ty'Union[None, str]', ty'str'), 76),
 ((ty'Dict[str, typing.Any]', ty'Dict'), 57),
 ((ty'bool', ty'Union[None, bool]'), 44),
 ((ty'ptm.Mock', ty'ptm.MockFixture'), 41),
 ((ty'str', ty'Union[None, str]'), 38),
 ((ty'List', ty'List[str]'), 31),
 ((ty'SpecialForm', ty'ISeq'), 27),
 ((ty'str', ty'runtime.Namespace'), 27),
 ((ty'Dict[str, typing.Any]', ty'Dict[str, Any]'), 25),
 ((ty'None', ty'str'), 24),
 ((ty'List[str]', ty'List'), 24),
 ((ty'str', ty'uuid.UUID'), 24),
 ((ty'None', ty'Any'), 23),
 ((ty'None', ty'bool'), 21),
 ((ty'sym.Namespace', ty'runtime.Namespace'), 21),
 ((ty'logging.Logger', ty'logging.LogCaptureFixture'), 21)]
[((ty'str', ty'int'), 125),
 ((ty'typing.Any', ty'Any'), 97),
 ((ty'Union[None, str]', ty'str'), 76),
 ((ty'Dict[str, typing.Any]', ty'Dict'), 57),
 ((ty'bool', ty'Union[None, bool]'), 44),

In [9]:
typet5_metrics_4_hitt5(task=TypeCollectionCategory.VARIABLE)

Loading predictions from VARIABLE: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 23.09it/s]


len(label_signatures)=50, len(pred_signatures)=50


[HiType4PyN1 @ INFO]: n_annots: 30070, n_labels: 16520
[HiType4PyN1 @ INFO]: Ratio: 0.5493847688726305


Accuracies on all types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
45.09 & 48.55 & 51.66 & 21.78 & 54.87
Accuracies on common types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
62.32 & 61.36 & 64.66 & 30.62 & 66.24
Accuracies on rare types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
9.92 & 29.90 & 32.36 & 10.89 & 34.24


In [10]:
typet5_metrics_4_hitt5(task=TypeCollectionCategory.CALLABLE_RETURN)

Loading predictions from CALLABLE_RETURN: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 23.66it/s]


len(label_signatures)=50, len(pred_signatures)=50


[HiType4PyN1 @ INFO]: n_annots: 30070, n_labels: 16520
[HiType4PyN1 @ INFO]: Ratio: 0.5493847688726305


Accuracies on all types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
45.11 & 48.66 & 51.82 & 21.87 & 54.99
Accuracies on common types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
62.36 & 61.54 & 64.84 & 31.05 & 66.41
Accuracies on rare types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
9.91 & 29.92 & 32.47 & 10.68 & 34.28


In [11]:
typet5_metrics_4_hitt5(task=TypeCollectionCategory.CALLABLE_PARAMETER)

Loading predictions from CALLABLE_PARAMETER: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 23.50it/s]


len(label_signatures)=50, len(pred_signatures)=50


[HiType4PyN1 @ INFO]: n_annots: 30070, n_labels: 16520
[HiType4PyN1 @ INFO]: Ratio: 0.5493847688726305


Accuracies on all types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
45.11 & 48.66 & 51.82 & 21.87 & 54.99
Accuracies on common types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
62.36 & 61.54 & 64.84 & 31.05 & 66.41
Accuracies on rare types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
9.91 & 29.92 & 32.47 & 10.68 & 34.28


In [12]:
typet5_metrics_4_hity4py(task="all")

Loading predictions from all: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:04<00:00, 11.92it/s]


len(label_signatures)=50, len(pred_signatures)=50


[HiType4PyN1 @ INFO]: n_annots: 30070, n_labels: 16520
[HiType4PyN1 @ INFO]: Ratio: 0.5493847688726305


Accuracies on all types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
45.09 & 48.55 & 51.66 & 21.78 & 54.87
Accuracies on common types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
62.32 & 61.36 & 64.66 & 30.62 & 66.24
Accuracies on rare types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
9.92 & 29.90 & 32.36 & 10.89 & 34.24
