In [1]:
%load_ext autoreload
%autoreload 2

%pwd
%cd /home/benji/Documents/Uni/heidelberg/05/masterarbeit/impls/scripts/experiments

/home/benji/Documents/Uni/heidelberg/05/masterarbeit/impls/scripts/experiments


In [19]:
import pandas as pd

pd.set_option('display.max_columns', None)  # or 1000
pd.set_option('display.max_rows', None)  # or 1000
pd.set_option('display.max_colwidth', None)  # or 199


import pathlib, tqdm

from scripts.common.schemas import TypeCollectionCategory
from scripts.infer.structure import DatasetFolderStructure

tool = "HiType4PyN1"
dataset = DatasetFolderStructure(pathlib.Path(
    "/home/benji/Documents/Uni/heidelberg/05/masterarbeit/datasets/better-types-4-py-dataset"
))
tasks = [TypeCollectionCategory.VARIABLE, TypeCollectionCategory.CALLABLE_PARAMETER, TypeCollectionCategory.CALLABLE_RETURN]

projects = list(dataset.test_set())


In [20]:
import logging
from importlib import reload

logging.shutdown()
reload(logging)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
for handler in logger.handlers:
    logger.removeHandler(handler)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)

ch.setFormatter(logging.Formatter(f"[{tool} @ %(levelname)s]: %(message)s"))
logger.addHandler(ch)

logger.info("Hello World!")

[HiType4PyN1 @ INFO]: Hello World!


In [37]:
from pandas.errors import MergeError

from scripts.common.schemas import TypeCollectionSchema
from scripts.common.output import DatasetIO, InferredIO

for project in (pbar := tqdm.tqdm(projects, desc="Ensuring inner join has 1 to 1 correlation with ground truth labels")):
    repository = str(dataset.author_repo(project))
    pbar.set_postfix({"repo": repository})
    
    ground_truth_dataset_io = DatasetIO(
        artifact_root=pathlib.Path(),
        dataset=dataset,
        repository=project,
    )
    
    pred_datasets_io = [InferredIO(
        artifact_root=pathlib.Path(),
        dataset=dataset,
        repository=project,
        tool_name=tool,
        task=task
    ) for task in tasks]
    
    
    ground_truth = ground_truth_dataset_io.read()
    ground_truth_for_task = ground_truth[
        ground_truth[TypeCollectionSchema.category].isin(tasks)
    ].assign(repository=repository)
    
    pred_dataset = pd.concat([
        label_dataset_io.read() 
        for label_dataset_io in label_datasets_io
    ], ignore_index=True).assign(repository=repository)
    
    
    #### Find and remove duplicate ground truth labels
    duplicate_ground_truths = ground_truth_for_task[ground_truth_for_task.duplicated(
        subset=[TypeCollectionSchema.file, TypeCollectionSchema.category, TypeCollectionSchema.qname, TypeCollectionSchema.qname_ssa], 
        keep=False
    )]
    if not duplicate_ground_truths.empty:
        logger.warning(f"{len(duplicate_ground_truths)} duplicate keys found in ground truth: \n{duplicate_ground_truths.to_string()}")
        ground_truth_for_task = ground_truth_for_task.drop_duplicates(
            subset=[TypeCollectionSchema.file, TypeCollectionSchema.category, TypeCollectionSchema.qname, TypeCollectionSchema.qname_ssa], 
            keep=False,
            ignore_index=True,
        )
    
    
    #### Find and remove duplicate inference labels
    duplicate_predictions = pred_dataset[pred_dataset.duplicated(
        subset=[TypeCollectionSchema.file, TypeCollectionSchema.category, TypeCollectionSchema.qname, TypeCollectionSchema.qname_ssa], 
        keep=False,
    )]
    if not duplicate_predictions.empty:
        logger.warning(f"{len(duplicate_predictions)} duplicate keys found in prediction dataset: \n{duplicate_predictions.to_string()}")
        pred_dataset = pred_dataset.drop_duplicates(
            subset=[TypeCollectionSchema.file, TypeCollectionSchema.category, TypeCollectionSchema.qname, TypeCollectionSchema.qname_ssa], 
            keep=False,
            ignore_index=True,
        )

    
    ### Validate one-to-one post-dedup
    pd.merge(
        left=ground_truth_for_task,
        right=pred_dataset,
        on=[TypeCollectionSchema.file, TypeCollectionSchema.category, TypeCollectionSchema.qname, TypeCollectionSchema.qname_ssa],
        validate="1:1"
    )
    
    

                        file            category                     qname                 qname_ssa          anno       method  topn                repository
773  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  linw1995__data_extractor
774  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  linw1995__data_extractor
904  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters  builtins.int  HiType4PyN1     1  linw1995__data_extractor
905  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters           NaN  HiType4PyN1     1  linw1995__data_extractor
                        file            category                     qname                 qname_ssa          anno       method  topn              repository
773  gpxtrackposter/track.py  CALLABLE_PAR

                        file            category                     qname                 qname_ssa          anno       method  topn                          repository
773  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  ShadowTemplate__beautiful-python-3
774  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  ShadowTemplate__beautiful-python-3
904  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters  builtins.int  HiType4PyN1     1  ShadowTemplate__beautiful-python-3
905  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters           NaN  HiType4PyN1     1  ShadowTemplate__beautiful-python-3
                        file            category                     qname                 qname_ssa          anno       method  topn             repo

                               file            category                 qname             qname_ssa anno            repository
924  tests/test_delegated_lookup.py     CALLABLE_RETURN       Provider.lookup       Provider.lookup  NaN  AxelVoitier__lookups
925  tests/test_delegated_lookup.py  CALLABLE_PARAMETER  Provider.lookup.self  Provider.lookup.self  NaN  AxelVoitier__lookups
926  tests/test_delegated_lookup.py     CALLABLE_RETURN       Provider.lookup       Provider.lookup  NaN  AxelVoitier__lookups
927  tests/test_delegated_lookup.py  CALLABLE_PARAMETER  Provider.lookup.self  Provider.lookup.self  NaN  AxelVoitier__lookups
                        file            category                     qname                 qname_ssa          anno       method  topn            repository
773  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  AxelVoitier__lookups
774  gpxtrackposter/track.py  CALLABLE_PARAMETER  Tra

                        file            category                     qname                 qname_ssa          anno       method  topn          repository
773  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  srittau__FakeSMTPd
774  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  srittau__FakeSMTPd
904  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters  builtins.int  HiType4PyN1     1  srittau__FakeSMTPd
905  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters           NaN  HiType4PyN1     1  srittau__FakeSMTPd
                        file            category                     qname                 qname_ssa          anno       method  topn                  repository
773  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meter

                                   file            category                             qname                         qname_ssa                                              anno               repository
102        src/topicdb/models/entity.py     CALLABLE_RETURN                Entity.instance_of                Entity.instance_of                                      builtins.str  brettkromkamp__topic-db
103        src/topicdb/models/entity.py  CALLABLE_PARAMETER           Entity.instance_of.self           Entity.instance_of.self                                               NaN  brettkromkamp__topic-db
104        src/topicdb/models/entity.py     CALLABLE_RETURN                Entity.instance_of                Entity.instance_of                                              None  brettkromkamp__topic-db
105        src/topicdb/models/entity.py  CALLABLE_PARAMETER           Entity.instance_of.self           Entity.instance_of.self                                               NaN  brettkrom

                        file            category                     qname                 qname_ssa          anno       method  topn               repository
773  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  brettkromkamp__topic-db
774  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  brettkromkamp__topic-db
904  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters  builtins.int  HiType4PyN1     1  brettkromkamp__topic-db
905  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters           NaN  HiType4PyN1     1  brettkromkamp__topic-db
                        file            category                     qname                 qname_ssa          anno       method  topn                 repository
773  gpxtrackposter/track.py  CALLABLE_PARAM

                      file            category                             qname                         qname_ssa           anno   repository
5    aql/engines/sqlite.py     CALLABLE_RETURN       SqliteConnection.autocommit       SqliteConnection.autocommit  builtins.bool  jreese__aql
6    aql/engines/sqlite.py  CALLABLE_PARAMETER  SqliteConnection.autocommit.self  SqliteConnection.autocommit.self            NaN  jreese__aql
7    aql/engines/sqlite.py     CALLABLE_RETURN       SqliteConnection.autocommit       SqliteConnection.autocommit           None  jreese__aql
8    aql/engines/sqlite.py  CALLABLE_PARAMETER  SqliteConnection.autocommit.self  SqliteConnection.autocommit.self            NaN  jreese__aql
479    aql/engines/base.py     CALLABLE_RETURN             Connection.autocommit             Connection.autocommit  builtins.bool  jreese__aql
480    aql/engines/base.py  CALLABLE_PARAMETER        Connection.autocommit.self        Connection.autocommit.self            NaN  jreese__aql

                        file            category                     qname                 qname_ssa          anno       method  topn      repository
773  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  JakobGM__quelf
774  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  JakobGM__quelf
904  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters  builtins.int  HiType4PyN1     1  JakobGM__quelf
905  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters           NaN  HiType4PyN1     1  JakobGM__quelf
                        file            category                     qname                 qname_ssa          anno       method  topn              repository
773  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_met

                        file            category                     qname                 qname_ssa          anno       method  topn                     repository
773  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  everyclass__everyclass-server
774  gpxtrackposter/track.py  CALLABLE_PARAMETER  Track.length_meters.self  Track.length_meters.self           NaN  HiType4PyN1     1  everyclass__everyclass-server
904  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters  builtins.int  HiType4PyN1     1  everyclass__everyclass-server
905  gpxtrackposter/track.py     CALLABLE_RETURN       Track.length_meters       Track.length_meters           NaN  HiType4PyN1     1  everyclass__everyclass-server
                        file            category                     qname                 qname_ssa            anno             repository
445  gpxtrackposter/track.py     CA

In [None]:
all_ground_truths_vs_predictions = pl.concat(queries_for_inner)
logger.info(f"Unprocessed sample size: {all_ground_truths_vs_predictions.shape}")
        
# Just about every publication does this: Ignore typing.Any, Any and None as it is not helpful to predict these
without_useless_annotations = all_ground_truths_vs_predictions.filter(~pl.col(ExtendedInferredSchema.anno).is_in(["typing.Any", "Any", "None"]))
logger.info(f"After removing typing.Any, Any and None from ground truth: {without_useless_annotations.shape}")
    
# Remove everything where there is no label in the ground truth dataset
no_ground_truth_label = without_useless_annotations.drop_nulls(subset=ExtendedInferredSchema.anno)
logger.info(f"After removing NULLs in ground truth annotation column: {no_ground_truth_label.shape}")

# Remove all whitespace
no_ground_truth_label = no_ground_truth_label.with_columns([
    pl.col(ExtendedInferredSchema.anno).str.replace_all(" ", ""),
    pl.col(ExtendedInferredSchema.parametric_anno).str.replace_all(" ", ""),
    pl.col("anno_predict").str.replace_all(" ", ""),
    pl.col("parametric_anno_predict").str.replace_all(" ", ""),
])
#print(no_ground_truth_label.sample(n=30))
            
# Remove trivial functions from the dataset: 'init' 'str', 'unicode', 'repr', 'len', 'doc', 'sizeof'
# adapted taken from type4py's preprocessing

# Retain everything that is not a callable or (if it is!) does not end with one of these methods
without_trivial_functions = no_ground_truth_label.filter(
    ~((pl.col(ExtendedTypeCollectionSchema.category) == str(TypeCollectionCategory.CALLABLE_RETURN)) &
    (
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__init__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__str__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__unicode__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__repr__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__len__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__doc__") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".__sizeof__")
    ))
)
logger.info(f"After removing trivially inferrable functions: {without_trivial_functions.shape}")


# Remove parameters for which it makes no sense to make predictions for (cls, self)
without_trivial_parameters = without_trivial_functions.filter(
    ~((pl.col(ExtendedTypeCollectionSchema.category) == str(TypeCollectionCategory.CALLABLE_PARAMETER)) &
    (
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".cls") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".self") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".args") |
            pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with(".kwargs")
    ))
)
logger.info(f"After removing trivially inferrable parameters: {without_trivial_parameters.shape}")


# Remove variables that are not intended to be seen anyway
without_useless_variables = without_trivial_parameters.filter(
    ~((pl.col(ExtendedTypeCollectionSchema.category) == str(TypeCollectionCategory.VARIABLE)) &
    (
        pl.col(ExtendedTypeCollectionSchema.qname).str.ends_with("._")
    ))
)
logger.info(f"After removing useless variables: {without_useless_variables.shape}")


filtered_ground_truths_vs_predictions = without_useless_variables
#logger.info(filtered_ground_truths_vs_predictions.select(
#    pl.col("repository"),
#    pl.col(ExtendedInferredSchema.file),
#    pl.col(ExtendedInferredSchema.category),
#    pl.col(ExtendedInferredSchema.qname_ssa),
#    pl.col(ExtendedInferredSchema.anno),
#    pl.col("anno_predict"),
#).sample(n=20))

# Remove TT5 artifacts
filtered_ground_truths_vs_predictions = filtered_ground_truths_vs_predictions.filter(
    pl.col("anno_predict") != "..."
)


logger.info(filtered_ground_truths_vs_predictions.select(
    pl.col("repository"),
    pl.col(ExtendedInferredSchema.file),
    pl.col(ExtendedInferredSchema.category),
    pl.col(ExtendedInferredSchema.qname_ssa),
    pl.col(ExtendedInferredSchema.anno),
    pl.col("anno_predict"),
).sample(n=20))

# Dequalify all annotations as TypeT5 did
def dequalify(annotation: str | None) -> str | None:
    import libcst
    from libcst import matchers as m
    class Dequalifier(libcst.CSTTransformer):
        def __init__(self) -> None:
            super().__init__()

        def leave_Attribute(
            self, original_node: libcst.Attribute, updated_node: libcst.Attribute
        ) -> libcst.Name:
            return updated_node.attr
        
        def leave_Subscript(
            self,
            original_node: libcst.Subscript,
            updated_node: libcst.Subscript,
        ) -> libcst.Subscript:
            import typing
            from scripts.common import _stringify
            
            if not m.matches(updated_node, m.Subscript(value=m.Name("Union") | m.Attribute(m.Name(), m.Name("Union")))):
                return updated_node

            # Resort union entries
            subscript_elems_as_str = [_stringify(se.slice.value) for se in updated_node.slice]
            as_sorted = sorted(subscript_elems_as_str)

            sorted_union = libcst.parse_expression(
                f"Union[{','.join(as_sorted)}]"
            )

            return typing.cast(libcst.Subscript, sorted_union)
    
    if annotation is None:
        return None
    return libcst.parse_module(annotation).visit(Dequalifier()).code

filtered_ground_truths_vs_predictions = filtered_ground_truths_vs_predictions.with_columns([
    pl.col(ExtendedInferredSchema.anno).apply(lambda a: dequalify(a)).alias("rewritten_anno"),
    pl.col(ExtendedInferredSchema.parametric_anno).apply(lambda a: dequalify(a)).alias("rewritten_parametric_anno"),
    pl.col("anno_predict").apply(lambda a: dequalify(a)).alias("rewritten_anno_predict"),
    pl.col("parametric_anno_predict").apply(lambda a: dequalify(a)).alias("rewritten_parametric_anno_predict"),
])