# BT4Py Type4Py Top-1

In [5]:
%load_ext autoreload
%autoreload 2

%pwd
%cd /home/benji/Documents/Uni/heidelberg/05/masterarbeit/impls/scripts/experiments

/home/benji/Documents/Uni/heidelberg/05/masterarbeit/impls/scripts/experiments


In [6]:
import polars as pl

pl.Config.set_fmt_str_lengths(300)
pl.Config.set_tbl_rows(n=50)

polars.config.Config

In [7]:
import pathlib

from scripts.common.schemas import TypeCollectionCategory
from scripts.infer.structure import DatasetFolderStructure

tool = "type4pyN1"
dataset = DatasetFolderStructure(pathlib.Path(
    "/home/benji/Documents/Uni/heidelberg/05/masterarbeit/datasets/better-types-4-py-dataset"
))


In [8]:
import logging
from importlib import reload

logging.shutdown()
reload(logging)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
for handler in logger.handlers:
    logger.removeHandler(handler)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)

ch.setFormatter(logging.Formatter(f"[{tool} @ %(levelname)s]: %(message)s"))
logger.addHandler(ch)

logger.info("Hello World!")


[type4pyN1 @ INFO]: Hello World!


# Loading Inference Task

In [5]:
from experiments import api

reload(api)

tasks = [
    TypeCollectionCategory.CALLABLE_PARAMETER, 
    TypeCollectionCategory.CALLABLE_RETURN, 
    TypeCollectionCategory.VARIABLE
]
inferreds = {
    task: dict(api.tasks.inferreds(dataset=dataset, tool=tool, task=task))
    for task in tasks
}

for task in tasks:
    task_inferred = inferreds[task]
    assert len(task_inferred) > 0, f"Did not find any datasets for {task}!"
    logger.info(f"Found {len(task_inferred)} files for DataFrame Inference for {task}")



[type4pyN1 @ INFO]: Found 50 files for DataFrame Inference for CALLABLE_PARAMETER
[type4pyN1 @ INFO]: Found 50 files for DataFrame Inference for CALLABLE_RETURN
[type4pyN1 @ INFO]: Found 50 files for DataFrame Inference for VARIABLE


# Loading Ground Truths

In [18]:
from experiments import api

reload(api)

ground_truths = api.tasks.extended_ground_truths(dataset=dataset)
logger.info(f"Found {len(ground_truths)} files with ground truths")

assert len(ground_truths) > 0, f"Did not find any datasets!"


[type4pyN1 @ INFO]: Found 50 files with ground truths


# Ensure Datapoints line up

In [19]:
import tqdm

from scripts.infer.structure import AuthorRepo
from scripts.common.schemas import ExtendedInferredSchema, ExtendedTypeCollectionSchema
import polars as pl

pl.Config.set_tbl_width_chars(1000)

#missing_in_inferred = set(ground_truths.keys()).difference(inferreds.keys())
#missing_in_gt = set(inferreds.keys()).difference(ground_truths.keys())

#if missing_in_inferred:
#    logger.warning(f"Could not find {missing_in_inferred} in inference dataset")

#if missing_in_gt:
#    logger.warning(f"Could not find {missing_in_gt} in inference dataset")

# assert not (missing_in_inferred and missing_in_gt), f"Missing repositories!"

missing_ground_truth = list()
missing_inferred = list()

tasks_str = list(map(str, tasks))

for author_repo in tqdm.tqdm(ground_truths):
    # Select by category
    ground_truth = pl.scan_csv(source=ground_truths[author_repo], null_values=[""]).filter(
        pl.col(ExtendedInferredSchema.category).is_in(tasks_str)
    )
    inferred = pl.concat([
        pl.scan_csv(source=inferreds[task][author_repo], null_values=[""])
        for task in tasks
    ])

    gt_no_anno = ground_truth.drop(columns=ExtendedInferredSchema.anno)
    inferred_no_anno = inferred.drop(
        columns=[ExtendedInferredSchema.anno, ExtendedInferredSchema.method, ExtendedInferredSchema.topn]
    )

    missing_from_inferred = gt_no_anno.join(
        inferred_no_anno,
        on=[
            ExtendedInferredSchema.file,
            ExtendedInferredSchema.category,
            ExtendedInferredSchema.qname,
            ExtendedInferredSchema.qname_ssa,
        ],
        how="anti",
    ).with_columns(pl.lit(f"{author_repo}").alias("repository"))
    missing_from_ground_truth = inferred_no_anno.join(
        gt_no_anno,
        on=[
            ExtendedInferredSchema.file,
            ExtendedInferredSchema.category,
            ExtendedInferredSchema.qname,
            ExtendedInferredSchema.qname_ssa,
        ],
        how="anti",
    ).with_columns(pl.lit(f"{author_repo}").alias("repository"))

    missing_inferred.append(missing_from_inferred)
    missing_ground_truth.append(missing_from_ground_truth)

gt_success = True
for author_repo, missing_entries in zip(
    ground_truths,
    pl.collect_all(missing_inferred, streaming=True, common_subplan_elimination=False),
):
    gt_success = gt_success and missing_entries.is_empty()
    if not missing_entries.is_empty():
        logger.error(f"Anti-Join for Inferred Truth shows missing entries for {author_repo}")
        logger.error(f"{missing_entries}")


inf_success = True
for author_repo, missing_entries in zip(
    ground_truths,
    pl.collect_all(missing_ground_truth, streaming=True, common_subplan_elimination=False),
):
    inf_success = inf_success and missing_entries.is_empty()
    if not missing_entries.is_empty():
        logger.error(f"Anti-Join for Ground Truth shows missing entries for {author_repo}")
        logger.error(f"{missing_entries}")


assert gt_success and inf_success, "Some labels did not line up! Check the output!"


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 275.45it/s]


# Join Datapoints based on file, category and qname_ssa

In [8]:
import polars as pl

queries_for_inner = []
queries_for_missing = []
for author_repo in tqdm.tqdm(ground_truths):
    ground_truth = pl.read_csv(source=ground_truths[author_repo], null_values=[""]).filter(
        pl.col(ExtendedInferredSchema.category).is_in(tasks_str)
    ).with_columns(pl.lit(str(author_repo)).alias("repository"))
    #print(ground_truth)

    inferred = pl.concat([
        pl.read_csv(source=inferreds[task][author_repo], null_values=[""])
        for task in tasks
    ]).with_columns(pl.lit(str(author_repo)).alias("repository"))
    #print(inferred)

    
    gt_vs_inf = ground_truth.join(
        inferred,
        on=[
            "repository",
            ExtendedInferredSchema.file,
            ExtendedInferredSchema.category,
            ExtendedInferredSchema.qname,
            ExtendedInferredSchema.qname_ssa,
        ],
        how="inner",
        # validate="1:1",
        suffix=("_predict"),
    )
    #print(gt_vs_inf)
    # print(gt_vs_inf)
    queries_for_inner.append(gt_vs_inf)


for author_repo, merged, _ in zip(
    ground_truths, 
    # pl.collect_all(queries_for_inner, streaming=True, common_subplan_elimination=False),
    queries_for_inner,
    range(3)
):
    for task in tasks_str:
        task_sample = (
            merged.select([
                pl.col("repository"),
                pl.col(ExtendedInferredSchema.file),
                pl.col(ExtendedInferredSchema.category),
                pl.col(ExtendedInferredSchema.qname_ssa),
                pl.col(ExtendedInferredSchema.anno),
                pl.col("anno_predict"),
                pl.col(ExtendedInferredSchema.parametric_anno),
                pl.col("parametric_anno_predict"),
            ]).filter(pl.col(ExtendedInferredSchema.category) == task).head(n=20))
        
        logger.info(f"{task}, {task_sample}")


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 396.16it/s]
[TypeT5TopN1 @ INFO]: CALLABLE_PARAMETER, shape: (20, 8)
┌──────────────────────────┬────────────────────┬────────────────────┬──────────────────────────────────────────────┬──────┬────────────────────────────────────────┬─────────────────┬────────────────────────────────────────┐
│ repository               ┆ file               ┆ category           ┆ qname_ssa                                    ┆ anno ┆ anno_predict                           ┆ parametric_anno ┆ parametric_anno_predict                │
│ ---                      ┆ ---                ┆ ---                ┆ ---                                          ┆ ---  ┆ ---                                    ┆ ---             ┆ ---                                    │
│ str                      ┆ 

[TypeT5TopN1 @ INFO]: VARIABLE, shape: (20, 8)
┌──────────────────────────┬────────────────────────────────────┬──────────┬───────────────────────────────────────────────────────┬──────┬──────────────┬─────────────────┬─────────────────────────┐
│ repository               ┆ file                               ┆ category ┆ qname_ssa                                             ┆ anno ┆ anno_predict ┆ parametric_anno ┆ parametric_anno_predict │
│ ---                      ┆ ---                                ┆ ---      ┆ ---                                                   ┆ ---  ┆ ---          ┆ ---             ┆ ---                     │
│ str                      ┆ str                                ┆ str      ┆ str                                                   ┆ str  ┆ str          ┆ str             ┆ str                     │
╞══════════════════════════╪════════════════════════════════════╪══════════╪═══════════════════════════════════════════════════════╪══════╪══════════════╪═══

[TypeT5TopN1 @ INFO]: CALLABLE_RETURN, shape: (20, 8)
┌────────────────────────┬──────────────────┬─────────────────┬─────────────────────────┬────────────────────────────────────────┬────────────────────────────────────────┬───────────────────┬─────────────────────────┐
│ repository             ┆ file             ┆ category        ┆ qname_ssa               ┆ anno                                   ┆ anno_predict                           ┆ parametric_anno   ┆ parametric_anno_predict │
│ ---                    ┆ ---              ┆ ---             ┆ ---                     ┆ ---                                    ┆ ---                                    ┆ ---               ┆ ---                     │
│ str                    ┆ str              ┆ str             ┆ str                     ┆ str                                    ┆ str                                    ┆ str               ┆ str                     │
╞════════════════════════╪══════════════════╪═════════════════╪═══════════

[TypeT5TopN1 @ INFO]: CALLABLE_PARAMETER, shape: (20, 8)
┌──────────────────┬─────────────────────────┬────────────────────┬────────────────────────────────────────────┬─────────────────────────────────────────┬─────────────────────────────────────────┬────────────────────────────┬────────────────────────────┐
│ repository       ┆ file                    ┆ category           ┆ qname_ssa                                  ┆ anno                                    ┆ anno_predict                            ┆ parametric_anno            ┆ parametric_anno_predict    │
│ ---              ┆ ---                     ┆ ---                ┆ ---                                        ┆ ---                                     ┆ ---                                     ┆ ---                        ┆ ---                        │
│ str              ┆ str                     ┆ str                ┆ str                                        ┆ str                                     ┆ str                    

[TypeT5TopN1 @ INFO]: VARIABLE, shape: (20, 8)
┌──────────────────┬──────────────────────────┬──────────┬────────────────────────────────────────────────────┬─────────────────────────────┬───────────────────────────────────┬─────────────────┬─────────────────────────┐
│ repository       ┆ file                     ┆ category ┆ qname_ssa                                          ┆ anno                        ┆ anno_predict                      ┆ parametric_anno ┆ parametric_anno_predict │
│ ---              ┆ ---                      ┆ ---      ┆ ---                                                ┆ ---                         ┆ ---                               ┆ ---             ┆ ---                     │
│ str              ┆ str                      ┆ str      ┆ str                                                ┆ str                         ┆ str                               ┆ str             ┆ str                     │
╞══════════════════╪══════════════════════════╪══════════╪═══════

In [67]:
all_ground_truths_vs_predictions = pl.concat(queries_for_inner)
logger.info(f"Unprocessed sample size: {all_ground_truths_vs_predictions.shape}")
        
# Just about every publication does this: Ignore typing.Any, Any and None as it is not helpful to predict these
without_useless_annotations = all_ground_truths_vs_predictions.filter(~pl.col(ExtendedInferredSchema.anno).is_in(["typing.Any", "Any", "None"]))
logger.info(f"After removing typing.Any, Any and None from ground truth: {without_useless_annotations.shape}")
    
# Remove everything where there is no label in the ground truth dataset
no_ground_truth_label = without_useless_annotations.drop_nulls(subset=ExtendedInferredSchema.anno)
logger.info(f"After removing NULLs in ground truth annotation column: {no_ground_truth_label.shape}")

# Remove all whitespace
no_ground_truth_label = no_ground_truth_label.with_columns([
    pl.col(ExtendedInferredSchema.anno).str.replace_all(" ", ""),
    pl.col(ExtendedInferredSchema.parametric_anno).str.replace_all(" ", ""),
    pl.col("anno_predict").str.replace_all(" ", ""),
    pl.col("parametric_anno_predict").str.replace_all(" ", ""),
])
#print(no_ground_truth_label.sample(n=30))
            
# Remove trivial functions from the dataset: 'init' 'str', 'unicode', 'repr', 'len', 'doc', 'sizeof'
# adapted taken from type4py's preprocessing

# Retain everything that is not a callable or (if it is!) does not end with one of these methods
without_trivial_functions = no_ground_truth_label.filter(
    ~((pl.col(ExtendedTypeCollectionSchema.category) == str(TypeCollectionCategory.CALLABLE_RETURN)) &
    (
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__init__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__str__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__unicode__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__repr__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__len__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__doc__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__sizeof__")
    ))
)
logger.info(f"After removing trivially inferrable functions: {without_trivial_functions.shape}")


# Remove parameters for which it makes no sense to make predictions for (cls, self)
without_trivial_parameters = without_trivial_functions.filter(
    ~((pl.col(ExtendedTypeCollectionSchema.category) == str(TypeCollectionCategory.CALLABLE_PARAMETER)) &
    (
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".cls") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".self") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".args") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".kwargs")
    ))
)
logger.info(f"After removing trivially inferrable parameters: {without_trivial_parameters.shape}")


# Remove variables that are not intended to be seen anyway
without_useless_variables = without_trivial_parameters.filter(
    ~((pl.col(ExtendedTypeCollectionSchema.category) == str(TypeCollectionCategory.VARIABLE)) &
    (
        pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with("._")
    ))
)
logger.info(f"After removing useless variables: {without_useless_variables.shape}")


filtered_ground_truths_vs_predictions = without_useless_variables
#logger.info(filtered_ground_truths_vs_predictions.select(
#    pl.col("repository"),
#    pl.col(ExtendedInferredSchema.file),
#    pl.col(ExtendedInferredSchema.category),
#    pl.col(ExtendedInferredSchema.qname_ssa),
#    pl.col(ExtendedInferredSchema.anno),
#    pl.col("anno_predict"),
#).sample(n=20))

# Remove TT5 artifacts
filtered_ground_truths_vs_predictions = filtered_ground_truths_vs_predictions.filter(
    pl.col("anno_predict") != "..."
)


logger.info(filtered_ground_truths_vs_predictions.select(
    pl.col("repository"),
    pl.col(ExtendedInferredSchema.file),
    pl.col(ExtendedInferredSchema.category),
    pl.col(ExtendedInferredSchema.qname_ssa),
    pl.col(ExtendedInferredSchema.anno),
    pl.col("anno_predict"),
).sample(n=20))

# Dequalify all annotations as TypeT5 did
def dequalify(annotation: str | None) -> str | None:
    import libcst
    from libcst import matchers as m
    class Dequalifier(libcst.CSTTransformer):
        def __init__(self) -> None:
            super().__init__()

        def leave_Attribute(
            self, original_node: libcst.Attribute, updated_node: libcst.Attribute
        ) -> libcst.Name:
            return updated_node.attr
        
        def leave_Subscript(
            self,
            original_node: libcst.Subscript,
            updated_node: libcst.Subscript,
        ) -> libcst.Subscript:
            import typing
            from scripts.common import _stringify
            
            if not m.matches(updated_node, m.Subscript(value=m.Name("Union") | m.Attribute(m.Name(), m.Name("Union")))):
                return updated_node

            # Resort union entries
            subscript_elems_as_str = [_stringify(se.slice.value) for se in updated_node.slice]
            as_sorted = sorted(subscript_elems_as_str)

            sorted_union = libcst.parse_expression(
                f"Union[{','.join(as_sorted)}]"
            )

            return typing.cast(libcst.Subscript, sorted_union)
    
    if annotation is None:
        return None
    return libcst.parse_module(annotation).visit(Dequalifier()).code

filtered_ground_truths_vs_predictions = filtered_ground_truths_vs_predictions.with_columns([
    pl.col(ExtendedInferredSchema.anno).apply(lambda a: dequalify(a)).alias("rewritten_anno"),
    pl.col(ExtendedInferredSchema.parametric_anno).apply(lambda a: dequalify(a)).alias("rewritten_parametric_anno"),
    pl.col("anno_predict").apply(lambda a: dequalify(a)).alias("rewritten_anno_predict"),
    pl.col("parametric_anno_predict").apply(lambda a: dequalify(a)).alias("rewritten_parametric_anno_predict"),
])

[TypeT5TopN1 @ INFO]: Unprocessed sample size: (58715, 13)
[TypeT5TopN1 @ INFO]: After removing typing.Any, Any and None from ground truth: (56130, 13)
[TypeT5TopN1 @ INFO]: After removing NULLs in ground truth annotation column: (16012, 13)
[TypeT5TopN1 @ INFO]: After removing trivially inferrable functions: (15909, 13)
[TypeT5TopN1 @ INFO]: After removing trivially inferrable parameters: (15758, 13)
[TypeT5TopN1 @ INFO]: After removing useless variables: (15758, 13)
[TypeT5TopN1 @ INFO]: shape: (20, 6)
┌────────────────────────────────────────┬─────────────────────────────────────────────────────────┬────────────────────┬─────────────────────────────────────────────────────────────────────────────────┬────────────────────────────────────────────────────────────┬────────────────────────────────────────────────────────────┐
│ repository                             ┆ file                                                    ┆ category           ┆ qname_ssa                                 

# Check transformed dataset

In [68]:
logger.info(filtered_ground_truths_vs_predictions.select(
    pl.col("repository"),
    pl.col(ExtendedInferredSchema.file),
    pl.col(ExtendedInferredSchema.category),
    pl.col(ExtendedInferredSchema.qname_ssa),
    pl.col(ExtendedInferredSchema.anno),
    pl.col("rewritten_anno"),
).sample(n=20))

[TypeT5TopN1 @ INFO]: shape: (20, 6)
┌────────────────────────────────────────┬─────────────────────────────────────────────────────────────────┬────────────────────┬───────────────────────────────────────────────────────┬──────────────────────────────────────────────────────────────────┬──────────────────────────────┐
│ repository                             ┆ file                                                            ┆ category           ┆ qname_ssa                                             ┆ anno                                                             ┆ rewritten_anno               │
│ ---                                    ┆ ---                                                             ┆ ---                ┆ ---                                                   ┆ ---                                                              ┆ ---                          │
│ str                                    ┆ str                                                             ┆ st

In [69]:
logger.info(filtered_ground_truths_vs_predictions.select(
    pl.col("repository"),
    pl.col(ExtendedInferredSchema.file),
    pl.col(ExtendedInferredSchema.category),
    pl.col(ExtendedInferredSchema.qname_ssa),
    pl.col(ExtendedInferredSchema.parametric_anno),
    pl.col("rewritten_parametric_anno"),
).sample(n=20))

[TypeT5TopN1 @ INFO]: shape: (20, 6)
┌────────────────────────────────────────┬────────────────────────────────────────────────────────────────┬────────────────────┬──────────────────────────────────────────────────────────────────────────┬──────────────────────────────────────────┬───────────────────────────┐
│ repository                             ┆ file                                                           ┆ category           ┆ qname_ssa                                                                ┆ parametric_anno                          ┆ rewritten_parametric_anno │
│ ---                                    ┆ ---                                                            ┆ ---                ┆ ---                                                                      ┆ ---                                      ┆ ---                       │
│ str                                    ┆ str                                                            ┆ str                ┆ str      

In [70]:
logger.info(filtered_ground_truths_vs_predictions.select(
    pl.col("repository"),
    pl.col(ExtendedInferredSchema.file),
    pl.col(ExtendedInferredSchema.category),
    pl.col(ExtendedInferredSchema.qname_ssa),
    pl.col("anno_predict"),
    pl.col("rewritten_anno_predict"),
).sample(n=20))

[TypeT5TopN1 @ INFO]: shape: (20, 6)
┌────────────────────────────────────────┬─────────────────────────────────────────────┬────────────────────┬─────────────────────────────────────────┬─────────────────────────────────────────────────────────────────────────────────────────┬────────────────────────────────┐
│ repository                             ┆ file                                        ┆ category           ┆ qname_ssa                               ┆ anno_predict                                                                            ┆ rewritten_anno_predict         │
│ ---                                    ┆ ---                                         ┆ ---                ┆ ---                                     ┆ ---                                                                                     ┆ ---                            │
│ str                                    ┆ str                                         ┆ str                ┆ str                         

In [71]:
logger.info(filtered_ground_truths_vs_predictions.select(
    pl.col("repository"),
    pl.col(ExtendedInferredSchema.file),
    pl.col(ExtendedInferredSchema.category),
    pl.col(ExtendedInferredSchema.qname_ssa),
    pl.col("parametric_anno_predict"),
    pl.col("rewritten_parametric_anno_predict"),
).sample(n=20))

[TypeT5TopN1 @ INFO]: shape: (20, 6)
┌──────────────────────────────────────┬─────────────────────────────────────────────────────┬────────────────────┬──────────────────────────────────────────────────────────────────────────────────┬──────────────────────────────────────────────────────────────────────────┬───────────────────────────────────┐
│ repository                           ┆ file                                                ┆ category           ┆ qname_ssa                                                                        ┆ parametric_anno_predict                                                  ┆ rewritten_parametric_anno_predict │
│ ---                                  ┆ ---                                                 ┆ ---                ┆ ---                                                                              ┆ ---                                                                      ┆ ---                               │
│ str                            

# Accept transformed dataset

In [72]:
filtered_ground_truths_vs_predictions = filtered_ground_truths_vs_predictions.with_columns([
    pl.col("rewritten_anno").alias(ExtendedInferredSchema.anno),
    pl.col("rewritten_parametric_anno").alias(ExtendedInferredSchema.parametric_anno),
    pl.col("rewritten_anno_predict").alias("anno_predict"),
    pl.col("rewritten_parametric_anno_predict").alias("parametric_anno_predict")
]).drop(["rewritten_anno", "rewritten_parametric_anno", "rewritten_anno_predict", "rewritten_parametric_anno_predict"])

In [73]:
logger.info(filtered_ground_truths_vs_predictions.select(
    pl.col("repository"),
    pl.col(ExtendedInferredSchema.file),
    pl.col(ExtendedInferredSchema.category),
    pl.col(ExtendedInferredSchema.qname_ssa),
    pl.col("anno"),
    pl.col("anno_predict"),
    pl.col("parametric_anno"),
    pl.col("parametric_anno_predict"),
).sample(n=20))

[TypeT5TopN1 @ INFO]: shape: (20, 8)
┌──────────────────────────────────────┬─────────────────────────────────────────────────────────────┬────────────────────┬─────────────────────────────────────────────────────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬─────────────────────┬─────────────────────────┐
│ repository                           ┆ file                                                        ┆ category           ┆ qname_ssa                                                                       ┆ anno                              ┆ anno_predict                      ┆ parametric_anno     ┆ parametric_anno_predict │
│ ---                                  ┆ ---                                                         ┆ ---                ┆ ---                                                                             ┆ ---                               ┆ ---                               ┆ ---                 ┆ ---        

#  Prediction Metrics for Full Accuracy with removed Variables

In [2]:
# Constants
# Common Type Names
from typet5.model import ModelWrapper
model = ModelWrapper.load_from_hub("MrVPlusOne/TypeT5-v7")
common_names = model.common_type_names
del model

2023-06-26 11:00:57.436176: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

In [3]:
from scripts.common.schemas import TypeCollectionCategory

import pprint

# Because our analysis reviews more datapoints than these models actually regard, reuse TypeT5 metrics instead
def typet5_metrics_4_type4py(task: TypeCollectionCategory | str) -> None:
    from scripts.common.output import InferenceArtifactIO

    test_set = dataset.test_set()

    proj2datasets = [
        (
            project,
            InferenceArtifactIO(
                artifact_root=pathlib.Path(),
                dataset=dataset,
                repository=project,
                tool_name=tool,
                task=task,
            ),
        )
        for project in test_set
    ]
    existing = dict(
        (project, artifact)
        for project, artifact in proj2datasets
        if artifact.full_location().exists()
    )
    
    from importlib import reload

    from libcst import helpers as h
    import tqdm


    from typet5.experiments import type4py
    from typet5.static_analysis import PythonProject, SignatureMap, AccuracyMetric, SignatureErrorAnalysis
    from typet5.experiments.typet5 import accs_as_table_row
    from typet5.visualization import pretty_print_dict
    reload(type4py)

    assignments = []
    projects = []

    for ctr, (project, artifact) in tqdm.tqdm(enumerate(existing.items()), desc=f"Loading labels and predictions from {task}"):            
        type4py_predictions, = artifact.read()
        # if ctr < 1:
            # pprint.pprint(type4py_predictions)

        for file, predictions in type4py_predictions.items():
            modpkg = h.calculate_module_and_package(repo_root=project, filename=project / file)
            parser = type4py.Type4PyResponseParser(modpkg.name)
            assignments.append(parser.parse({"response": predictions}))

        projects.append(PythonProject.parse_from_root(project))


    name2project = {p.name: p for p in projects}
    
    label_signatures: dict[str, SignatureMap] = {
        project.name: {e.path: e.get_signature() for e in project.all_elems()}
        for project in projects
    }
    pred_signatures: dict[str, SignatureMap] = {n: dict() for n in label_signatures}

    module_srcs = [
        (project.name, name)
        for project in projects
        for name in project.modules
    ]
    for (pname, mname), o in zip(module_srcs, assignments):
        if isinstance(o, str):
            if list(name2project[pname].modules[mname].all_elements()):
                # only warn for non-empty modules
                logger.warning(f"In project {pname} module {mname}, Type4Py errored: {o}")
        else:
            pred_signatures[pname].update(o)

    # print(pred_signatures)

    eval_result = type4py.Type4PyEvalResult(
        pred_maps=pred_signatures,
        label_maps=label_signatures,
    )
    
    metrics = AccuracyMetric.default_metrics(common_type_names=common_names)
    # acc_metric = AccuracyMetric(common_type_names=ubiq_names)

    n_annots = sum([e.get_signature().n_annots() for p in projects for e in p.all_elems()])
    n_labels = sum([e.n_annotated() for lm in eval_result.label_maps.values() for e in lm.values()])
    
    logger.info(f"n_annots: {n_annots}, n_labels: {n_labels}")
    logger.info(f"Ratio: {n_labels / n_annots}")
    
    accs = {
        m.name: SignatureErrorAnalysis(
            eval_result.pred_maps,
            eval_result.label_maps,
            m,
            error_on_mismatched_signature=False,
        ).accuracies
        for m in metrics
    }
    accs_as_table_row(accs)
    # pretty_print_dict(accs)

In [22]:
typet5_metrics_4_type4py(task=TypeCollectionCategory.VARIABLE)

Loading labels and predictions from VARIABLE: 50it [04:36,  5.54s/it]
[type4pyN1 @ INFO]: n_annots: 30070, n_labels: 16520
[type4pyN1 @ INFO]: Ratio: 0.5493847688726305


Accuracies on all types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
42.48 & 46.62 & 48.49 & 29.27 & 48.79
Accuracies on common types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
60.78 & 63.01 & 64.80 & 45.48 & 61.35
Accuracies on rare types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
1.20 & 19.04 & 20.67 & 5.13 & 22.09


In [23]:
typet5_metrics_4_type4py(task=TypeCollectionCategory.CALLABLE_RETURN)

Loading labels and predictions from CALLABLE_RETURN: 50it [04:47,  5.75s/it]
[type4pyN1 @ INFO]: n_annots: 30070, n_labels: 16520
[type4pyN1 @ INFO]: Ratio: 0.5493847688726305


Accuracies on all types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
42.48 & 46.69 & 48.59 & 29.26 & 48.87
Accuracies on common types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
60.79 & 63.14 & 64.95 & 45.62 & 61.49
Accuracies on rare types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
1.20 & 19.03 & 20.69 & 5.05 & 22.08


In [9]:
typet5_metrics_4_type4py(task=TypeCollectionCategory.CALLABLE_PARAMETER)

Loading labels and predictions from CALLABLE_PARAMETER: 50it [05:06,  6.13s/it]
[type4pyN1 @ INFO]: n_annots: 30070, n_labels: 16520
[type4pyN1 @ INFO]: Ratio: 0.5493847688726305


Accuracies on all types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
42.48 & 46.69 & 48.59 & 29.26 & 48.87
Accuracies on common types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
60.79 & 63.14 & 64.95 & 45.62 & 61.49
Accuracies on rare types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
1.20 & 19.03 & 20.69 & 5.05 & 22.08


In [11]:
typet5_metrics_4_type4py(task="all")

Loading labels and predictions from all: 50it [03:41,  4.44s/it]
[type4pyN1 @ INFO]: n_annots: 30070, n_labels: 16520
[type4pyN1 @ INFO]: Ratio: 0.5493847688726305


Accuracies on all types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
42.48 & 46.62 & 48.49 & 29.27 & 48.79
Accuracies on common types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
60.78 & 63.01 & 64.80 & 45.48 & 61.35
Accuracies on rare types:
header:  ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']
1.20 & 19.04 & 20.67 & 5.13 & 22.09


In [75]:
from sklearn import metrics, preprocessing
import numpy as np

logger.info(f"Computing metrics on {len(filtered_ground_truths_vs_predictions)} samples")


logger.info("=== Adjusted Accuracy on All Types ===")
typet5_metrics(filtered_ground_truths_vs_predictions, column=ExtendedInferredSchema.anno, simple_complex=None)

logger.info("=== Adjusted Accuracy on Simple Types ===")
typet5_metrics(filtered_ground_truths_vs_predictions, column=ExtendedInferredSchema.anno, simple_complex="simple")

logger.info("=== Adjusted Accuracy on Complex Types ===")
typet5_metrics(filtered_ground_truths_vs_predictions, column=ExtendedInferredSchema.anno, simple_complex="complex")


[TypeT5TopN1 @ INFO]: Computing metrics on 9629 samples
[TypeT5TopN1 @ INFO]: === Adjusted Accuracy on All Types ===
[TypeT5TopN1 @ INFO]: accuracy=0.7162737563609929
[TypeT5TopN1 @ INFO]: recall=0.4157398135142898
[TypeT5TopN1 @ INFO]: precision=0.4294910900648399
[TypeT5TopN1 @ INFO]: f1score=0.4064521212440365
[TypeT5TopN1 @ INFO]: === Adjusted Accuracy on Simple Types ===
[TypeT5TopN1 @ INFO]: accuracy=0.8057907675498498
[TypeT5TopN1 @ INFO]: recall=0.5157027665199498
[TypeT5TopN1 @ INFO]: precision=0.5431168458397478
[TypeT5TopN1 @ INFO]: f1score=0.510972058945364
[TypeT5TopN1 @ INFO]: === Adjusted Accuracy on Complex Types ===
[TypeT5TopN1 @ INFO]: accuracy=0.43216298222800176
[TypeT5TopN1 @ INFO]: recall=0.25617580870177764
[TypeT5TopN1 @ INFO]: precision=0.2853715979732787
[TypeT5TopN1 @ INFO]: f1score=0.26076723155109094


In [76]:
from sklearn import metrics, preprocessing
import numpy as np

logger.info(f"Computing metrics on {len(filtered_ground_truths_vs_predictions)} samples")


logger.info("=== Base Accuracy on All Types ===")
typet5_metrics(filtered_ground_truths_vs_predictions, column=ExtendedInferredSchema.parametric_anno, simple_complex=None)


[TypeT5TopN1 @ INFO]: Computing metrics on 9629 samples
[TypeT5TopN1 @ INFO]: === Base Accuracy on All Types ===
[TypeT5TopN1 @ INFO]: accuracy=0.7751583757399523
[TypeT5TopN1 @ INFO]: recall=0.5992065365372685
[TypeT5TopN1 @ INFO]: precision=0.6050543966959154
[TypeT5TopN1 @ INFO]: f1score=0.5794400364724952


In [14]:
from sklearn import metrics, preprocessing

builtin_types = np.unique(filtered_ground_truths_vs_predictions.filter(
    pl.col(ExtendedInferredSchema.parametric_anno).str.starts_with("builtins.")
).get_column(ExtendedInferredSchema.parametric_anno).to_numpy())
print(builtin_types)
display = metrics.ConfusionMatrixDisplay.from_predictions(y_true=groundtruth_labels, y_pred=model_predictions, labels=builtin_types, normalize="true")

['builtins.Exception' 'builtins.UnicodeError' 'builtins.bool'
 'builtins.bytearray' 'builtins.bytes' 'builtins.complex' 'builtins.dict'
 'builtins.float' 'builtins.int' 'builtins.list' 'builtins.object'
 'builtins.set' 'builtins.str' 'builtins.tuple' 'builtins.type']


NameError: name 'groundtruth_labels' is not defined

In [None]:
mock_mistakes = filtered_ground_truths_vs_predictions.select(pl.all()).filter(pl.col(ExtendedInferredSchema.parametric_anno).str.ends_with("Mock"))
#print(mock_mistakes.select(
#    pl.col("repository"),
#    pl.col(ExtendedInferredSchema.file),
#    pl.col(ExtendedInferredSchema.qname_ssa),
#    pl.col(ExtendedInferredSchema.parametric_anno),
#    pl.col("parametric_anno_predict"),
#))

groundtruth_labels = mock_mistakes.get_column(ExtendedInferredSchema.parametric_anno).to_numpy()
model_predictions = mock_mistakes.get_column("parametric_anno_predict").to_numpy()


display = metrics.ConfusionMatrixDisplay.from_predictions(
    y_true=groundtruth_labels, 
    y_pred=model_predictions, 
    labels=np.unique(np.concatenate((groundtruth_labels, model_predictions))), 
    normalize="true"
)

In [None]:
#  Prediction Metrics for Full Accuracy