In [None]:
import sys
sys.path.append("..")

import datetime
import random
import math
import time
import json
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display, HTML
import plotly
plotly.io.templates.default = "plotly_dark"
import plotly.express as px
import pandas as pd

from src.datasets import *
from src.util.image import *
from src.util import *
from src.util.files import *
from src.util.embedding import *
from src.algo import *
from src.models.encoder import *
from src.models.decoder import *
from src.models.util import *
from src.util.text_encoder import TextEncoder
from src.util.gharchive import GHArchive

In [None]:
# Global "bytefreq" text encoder. `iter_messages` uses it for similarity filtering,
# and the clustering cell uses it to embed the collected commit messages.
encoder = TextEncoder("bytefreq")

In [None]:
def iter_messages(
    dates: Iterable[datetime.date] = (
        datetime.date(2018, 1, 7),
        datetime.date(2021, 1, 7),
        datetime.date(2022, 1, 7),
        datetime.date(2023, 11, 20)
    ),
    message_buffer_size: int = 1_000_000,
    prob: float = 1.,
    encodings: Optional[torch.Tensor] = None,
    encoding_weights: Optional[torch.Tensor] = None,
):
    """
    Yield de-duplicated commit messages from GitHub-archive ``PushEvent``s.

    Events of all ``dates`` are interleaved round-robin (one event per still-active
    day per pass). A commit message that is still in the de-duplication buffer is
    skipped and counted as a skip.

    :param dates: days passed to ``GHArchive.iter_events(day=...)``
    :param message_buffer_size: when the de-duplication buffer reaches this size,
        it is cut down to the entries whose first-seen step is <= the median step
    :param prob: forwarded to ``GHArchive.iter_events`` as ``probability``
    :param encodings: optional 2-D matrix with one reference embedding per row.
        If given, messages are encoded in batches of 128 with the global
        ``encoder`` and only yielded when they satisfy every row's weight.
    :param encoding_weights: one threshold per row of ``encodings``.
        ``w >= 0`` keeps a message only if ``similarity >= w``;
        ``w < 0`` keeps it only if ``similarity <= -w``.
        Defaults to all ones, i.e. ``similarity >= 1.0`` for every row.
        NOTE(review): with normalized embeddings that default passes almost nothing;
        confirm it is intended.
    :raises ValueError: if ``encoding_weights`` does not have one entry per encoding
    """
    gharchive = GHArchive(verbose=True)

    if encodings is not None:
        assert encodings.ndim == 2, encodings.ndim
        if encoding_weights is None:
            encoding_weights = torch.ones(encodings.shape[0])
        elif len(encoding_weights) != encodings.shape[0]:
            # zip() used to silently truncate here; fail loudly instead
            raise ValueError(
                f"got {len(encoding_weights)} encoding_weights for {encodings.shape[0]} encodings"
            )

    def _iter_events():
        # One event source per day. iter() makes next() work even if
        # iter_events returns a plain iterable rather than an iterator.
        sources = [
            iter(gharchive.iter_events(
                day=date,
                event_type="PushEvent",
                probability=prob,
            ))
            for date in dates
        ]
        # round-robin: one event from every still-active day per pass
        while sources:
            active = []
            for source in sources:
                try:
                    event = next(source)
                except StopIteration:
                    continue
                yield event
                active.append(source)
            sources = active

    num_skipped = 0
    num_yielded = 0

    def _iter_messages():
        nonlocal num_skipped

        # message -> step at which it was first seen. Values are inserted in
        # strictly increasing order, and dicts keep insertion order.
        message_dict = {}
        num_processed = 0
        with tqdm() as progress:
            for event in _iter_events():
                for commit in event["payload"]["commits"]:
                    message = commit["message"]
                    if message in message_dict:
                        num_skipped += 1
                        continue

                    message_dict[message] = num_processed
                    num_processed += 1

                    yield message

                progress.update(1)
                progress.desc = (
                    f"messages/skips {num_yielded:,}/{num_skipped:,}"
                    f", buffer-size {len(message_dict):,}"
                    f", date={event['created_at']}"
                )

                if len(message_dict) >= message_buffer_size:
                    # Values are already sorted (insertion order), so the middle
                    # element is the median without sorting.
                    steps = list(message_dict.values())
                    median = steps[len(steps) // 2]
                    # Keep the first-seen half and drop the newer half.
                    # NOTE(review): this evicts the most recent messages; confirm intended.
                    message_dict = {
                        msg: step
                        for msg, step in message_dict.items()
                        if step <= median
                    }

    if encodings is None:
        for message in _iter_messages():
            yield message
            num_yielded += 1
    else:
        for batch in iter_batches(_iter_messages(), 128):
            with torch.no_grad():
                enc = encoder.encode(batch)

            # assumed (len(batch), num_encodings); as_tensor also accepts a numpy result
            sim_matrix = torch.as_tensor(enc @ encodings.T)
            weights = torch.as_tensor(encoding_weights, device=sim_matrix.device)
            # per (text, encoding): does the similarity violate that encoding's threshold?
            violated = torch.where(weights >= 0, sim_matrix < weights, sim_matrix > -weights)
            keep = ~violated.any(dim=1)

            for text, keep_text in zip(batch, keep.tolist()):
                if keep_text:
                    yield text
                    num_yielded += 1
                else:
                    num_skipped += 1
                
try:
    # Show the first 10 messages, then keep draining the stream (the progress bar
    # keeps updating) until it is exhausted or interrupted with Ctrl-C.
    for shown, m in enumerate(iter_messages(
            prob=1/10,
            #encodings=normalize_embedding(cluster_centers[5:6]), encoding_weights=torch.Tensor([.95]),
    )):
        if shown < 10:
            print(repr(m))
except KeyboardInterrupt:
    pass

# Cluster commit messages by byte frequency

In [None]:
from itertools import islice

import numpy as np
from sklearn import cluster as skcluster

total = 10000
num_clusters = 10

# Read the saved "short_texts" reference center from disk rather than from the
# in-memory `cluster_lib`, which is only created by a later cell (the old
# dependency broke Restart & Run All).
short_texts_center = pd.read_csv("./bytefreqs.csv", index_col=0)["short_texts"].to_numpy()

# collect up to `total` messages that are similar to the saved center
messages = iter_messages(
    #prob=1. / 10,
    #encodings=normalize_embedding(cluster_centers[5:6]), encoding_weights=torch.Tensor([.9]),
    encodings=normalize_embedding(short_texts_center[None, :]), encoding_weights=torch.Tensor([.9]),
)
texts = []
try:
    # islice pulls exactly `total` messages; zip(generator, range(total)) would
    # fetch and silently drop one extra message before stopping
    for m in islice(messages, total):
        texts.append(m)
        if len(texts) <= 10:
            print(repr(texts[-1]))
except KeyboardInterrupt:
    pass  # Ctrl-C stops collecting early; cluster whatever was gathered so far
finally:
    # finish the generator now so its event stream / progress bar are released
    messages.close()

if not texts:
    raise RuntimeError("no commit messages collected, nothing to cluster")

with torch.no_grad():
    embeddings = encoder.encode(texts).cpu().numpy()

print("clustering:", embeddings.shape)

#clusterer = skcluster.KMeans(num_clusters, n_init="auto")
#clusterer = skcluster.BisectingKMeans(num_clusters)
# fixed random_state: later cells refer to clusters by number, so ids must be reproducible
clusterer = skcluster.SpectralClustering(num_clusters, random_state=0)
labels = clusterer.fit_predict(embeddings)

# number of messages per cluster (labels are ints in [0, num_clusters))
hist = np.bincount(labels, minlength=num_clusters)
px.bar(hist)

## Get cluster centers

In [None]:
if hasattr(clusterer, "cluster_centers_"):
    # estimators like KMeans expose centers directly (used as-is, not re-normalized)
    cluster_centers = clusterer.cluster_centers_

else:
    # Estimators without `cluster_centers_` (e.g. SpectralClustering):
    # use the normalized mean embedding of each cluster's members.
    cluster_centers = normalize_embedding(np.stack([
        embeddings[labels == ci].mean(axis=0)
        for ci in range(num_clusters)
    ]))

# one column per cluster center, one row per embedding dimension
df = pd.DataFrame(cluster_centers.T)
# Label each row with its byte value (assumes one dimension per byte value):
# printable ASCII 0x20-0x7e as the character, everything else -- including DEL 0x7f -- as hex.
df["char"] = df.index.map(lambda i: chr(i) if 32 <= i < 127 else f"0x{i:02x}")
px.line(
    df,
    title=f"normalized byte frequencies of {len(texts):,} github commits, {num_clusters} cluster centers",
    hover_data=["char"],
)

# Display best and worst matching messages per cluster

In [None]:
num_examples = 5

for cluster_id in range(num_clusters):
    members = np.flatnonzero(labels == cluster_id)
    center = cluster_centers[cluster_id]

    # similarity of every member embedding to the (re-normalized) cluster center
    member_sims = normalize_embedding(center) @ embeddings[members].T
    ranking = np.argsort(member_sims)

    title = f"cluster #{cluster_id} -- {members.shape[0] / len(texts) * 100:.2f}% ({members.shape[0]} entries)"
    display(HTML(f"<h3>{title}</h3>"))
    display(px.bar(center, height=300, title=title))

    # highest similarities first for "best", lowest first for "worst"
    sections = (
        (" -- best match --", reversed(ranking[-num_examples:])),
        ("\n -- worst match --", ranking[:num_examples]),
    )
    for heading, positions in sections:
        print(heading)
        for pos in positions:
            print(f"  {member_sims[pos]:.3f} {repr(texts[members[pos]][:100])}")

    display(HTML("<hr/>"))
    

In [None]:
# Library of saved reference centers (presumably one column per named center; the
# `to_csv` cell below writes it). Start with an empty frame when nothing has been
# saved yet, instead of a dead `pd.DataFrame()` that was immediately overwritten.
bytefreqs_path = Path("./bytefreqs.csv")
if bytefreqs_path.exists():
    cluster_lib = pd.read_csv(bytefreqs_path, index_col=0)
else:
    cluster_lib = pd.DataFrame()
cluster_lib

In [None]:
# Cluster index picked by inspecting the per-cluster display above. Cluster ids
# depend on the clustering run, so re-check this number after re-clustering.
SHORT_TEXTS_CLUSTER = 3
cluster_lib["short_texts"] = cluster_centers[SHORT_TEXTS_CLUSTER]

In [None]:
# show the current contents of the reference-center library
cluster_lib

In [None]:
# persist the library; the loading cell reads it back with pd.read_csv(..., index_col=0)
cluster_lib.to_csv("./bytefreqs.csv")

In [None]:
# raw "short_texts" center vector (the clustering cell uses it as a similarity reference)
cluster_lib["short_texts"].to_numpy()

In [None]:
def encode_bytes(texts: Iterable[str], normalize: bool, with_numpy: bool) -> torch.Tensor:
    """
    Encode each text as a histogram of its UTF-8 byte values.

    :param texts: strings to encode; consumed once, so any iterable works
    :param normalize: if True, pass the result through ``normalize_embedding``
    :param with_numpy: count bytes with numpy (fast) instead of a pure-Python loop;
        both paths produce the same values
    :return: float tensor of shape ``(number of texts, 256)``; for no texts an
        empty ``(0, 256)`` tensor is returned (without normalization)
    """
    import numpy as np

    rows = []
    for text in texts:
        data = text.encode()
        if with_numpy:
            # bincount over uint8 values equals a 256-bin histogram on [0, 256), but is much faster
            counts = np.bincount(np.frombuffer(data, dtype=np.uint8), minlength=256)
            rows.append(counts[None, :])
        else:
            values = [0] * 256
            for byte in data:
                values[byte] += 1
            rows.append(torch.Tensor(values).unsqueeze(0))

    if not rows:
        # torch.concat / np.concatenate reject an empty list
        return torch.zeros(0, 256)

    if with_numpy:
        tensors = torch.Tensor(np.concatenate(rows))
    else:
        tensors = torch.concat(rows)
    if normalize:
        tensors = normalize_embedding(tensors)

    return tensors

# Benchmark the pure-Python vs. numpy byte-counting paths of `encode_bytes`.
# Use a separate name so the clustered `texts` from the cells above are not overwritten.
bench_texts = ["Assignment2: Regression and Classifier models\n\nIn this assignment I used supervised learning models."] * 10000

for with_numpy in (False, True):
    # perf_counter is monotonic and high-resolution, unlike time.time()
    start_time = time.perf_counter()
    encode_bytes(bench_texts, True, with_numpy)
    print(f"{time.perf_counter() - start_time:.3f}")