In [1]:
import re
from   pathlib import Path
from   typing import Callable, List, Literal, Tuple

import numpy as np
import pandas as pd
import torch

import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.decomposition import PCA
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

from datasets.functions_dataset import FunctionsDataset
from datasets.utils.collate import Collater

from similarity_measures.wl_similarity import WL
from utils.lightning_module import PlagiarismModel

In [2]:
# Paths of the function files that failed to parse, one per line of
# datasets/errors.txt; surrounding whitespace (including the newline) is
# stripped from every entry.
with open("datasets/errors.txt") as f:
    parsing_errors = list(map(str.strip, f))

class FunDataset(Dataset):
    """Map-style dataset over the function files stored in one directory.

    Every file directly inside ``data_root`` whose name matches the format's
    extension (``.R`` for ``"tokens"``, ``.R.txt`` otherwise) is one sample,
    unless its path, converted to ``str``, is listed in the module-level
    ``parsing_errors``.

    The class label of a sample is derived from its file name: the extension
    and any trailing digits are removed, and the remaining "base" names are
    numbered in sorted order. ``foo1.R`` and ``foo2.R`` therefore share one
    label.

    Args:
        data_root: Directory that contains the function files.
        f: Sample format. ``"tokens"`` selects ``*.R`` files; any other value
            selects ``*.R.txt`` files. Note that ``__getitem__`` parses a
            graph only when the value is exactly ``"graph"`` and tokenizes
            otherwise, so only the two documented values are meaningful.
        tensors: Forwarded as ``return_tensor`` to the ``FunctionsDataset``
            parser used in ``__getitem__``.

    Attributes set by the constructor: ``data_root``, ``functions`` (sorted
    list of kept paths), ``bases`` (base name per sample), ``base2class``
    (base name -> integer label), ``labels`` (label per sample), ``format``
    and ``tensors``.
    """

    def __init__(self, data_root: str, f: Literal["graph", "tokens"], tensors: bool = True):
        self.data_root = Path(data_root)
        ext = ".R" if f == "tokens" else ".R.txt"

        # A set makes the exclusion filter one hash lookup per file instead of
        # a scan over the whole error list.
        excluded = set(parsing_errors)
        self.functions = [
            path
            for path in sorted(self.data_root.glob(f"*{ext}"))
            if str(path) not in excluded
        ]

        # Remove the extension from the END of the name only. A plain
        # ``str.replace`` would also delete the same text elsewhere in the
        # name (e.g. "my.Rbar7.R" -> "mybar7") and could merge distinct bases.
        pattern = re.compile(r"\d*$")
        self.bases = [
            pattern.split(path.name[:-len(ext)])[0]
            for path in self.functions
        ]

        # np.unique returns the distinct bases in sorted order, so the label
        # assignment is deterministic for a given directory content.
        unique_bases = np.unique(self.bases)
        self.base2class = {base: i for i, base in enumerate(unique_bases)}
        self.labels = [self.base2class[base] for base in self.bases]
        self.format = f
        self.tensors = tensors

    def __len__(self) -> int:
        """Return the number of function files kept by the constructor."""
        return len(self.functions)

    def __getitem__(self, index):
        """Return ``(parsed function, integer label)`` for the given index."""
        path = self.functions[index]
        if self.format == "graph":
            function = FunctionsDataset.parse_graph(str(path), return_tensor=self.tensors)
        else:
            function = FunctionsDataset.tokenize(str(path), return_tensor=self.tensors)
        return function, self.labels[index]

In [3]:
@torch.no_grad()
def compute_embs(fn: Callable, loader: DataLoader, cuda: bool = True):
    """Embed every batch produced by ``loader`` and collect the labels.

    Args:
        fn: Callable applied to each batch of inputs; its result must support
            ``.cpu().numpy()``.
        loader: Iterable yielding ``(inputs, labels)`` pairs. ``labels`` must
            support ``.flatten().tolist()``, and ``inputs`` must support
            ``.cuda()`` when ``cuda`` is true.
        cuda: When true, each batch of inputs is moved with ``.cuda()`` before
            ``fn`` is called.

    Returns:
        Tuple ``(embeddings, labels)``: the per-batch embedding arrays
        concatenated along the first axis, and a NumPy array of all labels in
        iteration order.
    """
    emb_chunks, label_values = [], []
    for batch, targets in loader:
        label_values += targets.flatten().tolist()
        inputs = batch.cuda() if cuda else batch
        emb_chunks.append(fn(inputs).cpu().numpy())

    return np.concatenate(emb_chunks), np.array(label_values)

In [4]:
def maps_at_r(datasets, model, f, cuda=True, batch_size=64, num_workers=10):
    """Evaluate ``model`` on several datasets and summarise MAP@R.

    For every entry of ``datasets`` a ``FunDataset`` is built, all samples are
    embedded with ``compute_embs``, and the embeddings are scored against
    themselves with ``AccuracyCalculator``.

    Args:
        datasets: Iterable of dataset directories (anything ``str()`` turns
            into a path accepted by ``FunDataset``).
        model: Callable that maps a collated batch to embeddings; passed to
            ``compute_embs`` as ``fn``.
        f: Sample format forwarded to ``FunDataset`` (``"graph"`` or
            ``"tokens"``).
        cuda: Forwarded to ``compute_embs``; also decides whether the loader
            uses pinned host memory.
        batch_size: Number of samples per batch (default 64).
        num_workers: Number of ``DataLoader`` worker processes (default 10).

    Returns:
        ``{"map@r": mean, "std": std}``, where both statistics are taken with
        ``np.mean`` / ``np.std`` over the per-dataset
        ``mean_average_precision_at_r`` values.
    """
    maps = []
    acc_calc = AccuracyCalculator(include=("mean_average_precision_at_r", "r_precision", "precision_at_1"))
    for dataset_dir in tqdm(datasets):
        samples = FunDataset(str(dataset_dir), f, True)
        loader = DataLoader(
            samples,
            batch_size=batch_size,
            shuffle=False,
            # Page-locked host buffers only pay off when batches are copied to
            # a GPU, so skip them for CPU-only evaluation.
            pin_memory=cuda,
            num_workers=num_workers,
            collate_fn=Collater(),
        )
        embeddings, labels = compute_embs(model, loader, cuda)
        # Query and reference are the same embeddings; the flag tells the
        # calculator so.
        # NOTE(review): positional order is (query, reference, query_labels,
        # reference_labels) - confirm against the installed
        # pytorch-metric-learning version.
        res = acc_calc.get_accuracy(embeddings, embeddings,
                                    labels, labels,
                                    embeddings_come_from_same_source=True)
        maps.append(res["mean_average_precision_at_r"])

    return {"map@r": np.mean(maps), "std": np.std(maps)}

## TOKENS

In [44]:
# Evaluation setup for the token-based models. The names bound here (f,
# datasets, classifier, encoder) are read by the evaluation cells that follow.
f = "tokens"
# Every path named "10" at any depth below data/text/; each one is later
# handed to FunDataset as a dataset directory.
datasets = list(Path("data/text/").glob("**/10"))
classifier_path = "lightning_logs/version_502/checkpoints/epoch=17-step=17999.ckpt"
encoder_path    = "lightning_logs/version_499/checkpoints/epoch=24-step=24324.ckpt"
classifier = PlagiarismModel.load_from_checkpoint(classifier_path)
encoder    = PlagiarismModel.load_from_checkpoint(encoder_path)
# presumably torch.nn.Module semantics (eval mode, move to GPU) - TODO confirm
# against PlagiarismModel. The trailing ";" suppresses the cell's output.
classifier.eval().cuda()
encoder.eval().cuda();

In [45]:
# Mean and standard deviation of MAP@R for the encoder checkpoint, computed
# with the current values of the globals `datasets` and `f`.
maps_at_r(datasets, encoder.model, f)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 108/108 [04:36<00:00,  2.56s/it]


{'map@r': 0.9506968377086454, 'std': 0.04194855125595338}

In [46]:
# Mean and standard deviation of MAP@R for the classifier checkpoint, computed
# with the current values of the globals `datasets` and `f`.
maps_at_r(datasets, classifier.model, f)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 108/108 [04:32<00:00,  2.52s/it]


{'map@r': 0.9594160359561821, 'std': 0.034866103106324896}

## GRAPHS

In [5]:
# Evaluation setup for the graph-based models.
# NOTE(review): this cell rebinds the same globals as the tokens setup cell
# (f, datasets, classifier_path, encoder_path, classifier, encoder). Run the
# notebook top to bottom so the graph evaluations below use these values.
f = "graph"
# Every path named "10" at any depth below data/graphs/; each one is later
# handed to FunDataset as a dataset directory.
datasets = list(Path("data/graphs/").glob("**/10"))
classifier_path = "lightning_logs/version_472/checkpoints/epoch=10-step=1836.ckpt"
encoder_path    = "lightning_logs/version_460/checkpoints/epoch=34-step=5704.ckpt"
classifier = PlagiarismModel.load_from_checkpoint(classifier_path)
encoder    = PlagiarismModel.load_from_checkpoint(encoder_path)
# presumably torch.nn.Module semantics (eval mode, move to GPU) - TODO confirm
# against PlagiarismModel. The trailing ";" suppresses the cell's output.
classifier.eval().cuda()
encoder.eval().cuda();

In [48]:
# Mean and standard deviation of MAP@R for the classifier checkpoint, computed
# with the current values of the globals `datasets` and `f`.
maps_at_r(datasets, classifier.model, f)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 108/108 [02:32<00:00,  1.41s/it]


{'map@r': 0.9495405098199478, 'std': 0.04396203613427523}

In [49]:
# Mean and standard deviation of MAP@R for the encoder checkpoint, computed
# with the current values of the globals `datasets` and `f`.
maps_at_r(datasets, encoder.model, f)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 108/108 [02:29<00:00,  1.39s/it]


{'map@r': 0.9162808992161506, 'std': 0.057962361380680824}