In [None]:
import sys
sys.path.append("..")

import datetime
import random
import math
import time
import json
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display
import plotly
plotly.io.templates.default = "plotly_dark"
import plotly.express as px

from src.datasets import *
from src.util.image import *
from src.util import *
from src.util.files import *
from src.util.embedding import *
from src.algo import *
from src.models.encoder import *
from src.models.decoder import *
from src.models.util import *

In [None]:
# Notebook-level GHArchive instance with verbose output enabled.
# Nothing else in the visible notebook reads this global: iter_commits()
# constructs its own instance with verbose=False.
from src.util.gharchive import GHArchive
gharchive = GHArchive(verbose=True)

In [None]:
def iter_commits(
    start_date: datetime.date = datetime.date(2023, 11, 20),
    days: int = 7,
    message_buffer_size: int = 1_000_000,
):
    """
    Yield commits from GH Archive ``PushEvent`` records, skipping repeated messages.

    Events are read for ``days`` consecutive days beginning at ``start_date``.
    Each commit whose message is not currently held in the de-duplication
    buffer is yielded as a dict with the keys ``repo``, ``date``, ``sha`` and
    ``message``. A tqdm bar advances once per event and shows the
    yielded/skipped totals, the buffer size and the date of the last event.

    :param start_date: first day whose events are read
    :param days: number of consecutive days to read
    :param message_buffer_size: once the buffer holds at least this many
        messages it is pruned to roughly half its size; messages that were
        pruned can be yielded again later
    :return: generator of commit dicts
    """
    # Local import: the notebook's import cell never imports itertools
    # explicitly, so the name was only available if a star import leaked it.
    import itertools

    gharchive = GHArchive(verbose=False)

    # Chain the per-day event streams lazily, one day after the other.
    iterable = itertools.chain.from_iterable(
        gharchive.iter_events(
            day=start_date + datetime.timedelta(i),
            event_type="PushEvent",
        )
        for i in range(days)
    )

    # message -> value of `num_yielded` at the moment the message was first seen
    message_dict = {}
    num_skipped = 0
    num_yielded = 0

    with tqdm() as progress:
        for event in iterable:
            data = {
                "repo": event["repo"]["name"],
                "date": event["created_at"],
            }
            for commit in event["payload"]["commits"]:
                message = commit["message"]
                if message in message_dict:
                    num_skipped += 1
                    continue

                message_dict[message] = num_yielded

                yield {
                    **data,
                    "sha": commit["sha"],
                    "message": message,
                }
                num_yielded += 1

            progress.update(1)
            progress.desc = (
                f"messages/skips {num_yielded:,}/{num_skipped:,}"
                f", buffer-size {len(message_dict):,}"
                f", date={data['date']}"
            )

            # The emptiness check avoids indexing an empty list when
            # message_buffer_size is <= 0 and nothing has been buffered yet.
            if message_dict and len(message_dict) >= message_buffer_size:
                steps = sorted(message_dict.values())
                median = steps[len(steps) // 2]

                # NOTE(review): this keeps the entries recorded at or before
                # the median step, i.e. the OLDER half of the buffer, and drops
                # the most recently seen messages. Kept as in the original so
                # the generated dataset does not change; confirm whether
                # keeping the newer half (step >= median) was intended.
                message_dict = {
                    msg: step
                    for msg, step in message_dict.items()
                    if step <= median
                }
                
            
# Dry run: exhaust the generator with its default arguments and discard every
# commit, so the only visible result is the tqdm progress line
# (yielded/skipped counts, buffer size, last event date).
for c in iter_commits():
    pass

In [None]:
import gzip

def write_messages(
    filename
):
    """
    Stream every commit from ``iter_commits()`` into a gzip-compressed NDJSON file.

    One compact JSON object (no spaces after separators) is written per line.

    :param filename: path of the ``.gz`` file to create; opened in text mode
    """
    with gzip.open(filename, "wt") as sink:
        try:
            sink.writelines(
                json.dumps(entry, separators=(',', ':')) + "\n"
                for entry in iter_commits()
            )
        except KeyboardInterrupt:
            # Interrupting the kernel is the intended way to stop early: the
            # lines written so far are kept and the file is closed by `with`.
            pass

# Write the default iter_commits() range to a gzip NDJSON file below the
# current user's home directory (the "~" is resolved by expanduser()).
write_messages(Path("~/prog/data/gharchive/commits.ndjson.gz").expanduser())


In [None]:
# Scratch arithmetic: 1.7M * 512 expressed in units of 1024**2 -> 830.
# NOTE(review): presumably a size estimate for 1.7M embeddings of width 512
# (bytes per value not included) - confirm the intended unit.
1_700_000 * 512 // 1024 // 1024

In [None]:
# ClipSingleton.encode_text is what the embedding cells below call.
from src.models.clip import ClipSingleton
# IPython-only help lookup (trailing `?`); this is not valid plain Python and
# is a leftover from interactive exploration.
ClipSingleton.encode_text?

In [None]:
# Embed the stored commit messages with CLIP in batches of 128.
# `texts[i]` is the message that produced row i of `encodings`.


def _encode_text_batch(batch):
    """Encode `batch` with CLIP, pass it through normalize_embedding and return a float CPU tensor."""
    return normalize_embedding(
        ClipSingleton.encode_text(batch, truncate=True)
    ).cpu().float()


text_batch = []
encodings = []
texts = []

# Inference only, so run without autograd bookkeeping (same as find_messages).
with torch.no_grad():
    try:
        for commit in tqdm(iter_ndjson(Path("~/prog/data/gharchive/commits.ndjson.gz").expanduser()), total=3_500_000):
            text_batch.append(commit["message"])
            if len(text_batch) >= 128:
                encodings.append(_encode_text_batch(text_batch))
                texts.extend(text_batch)
                text_batch.clear()
    except KeyboardInterrupt:
        # Interrupting the kernel stops early and keeps the batches finished so far.
        pass
    else:
        # The file was read to the end: encode the final partial batch as well.
        # Previously these trailing (< 128) messages were silently dropped.
        if text_batch:
            encodings.append(_encode_text_batch(text_batch))
            texts.extend(text_batch)
            text_batch.clear()

encodings = torch.concat(encodings)
encodings.shape

In [None]:
# All-pairs similarity: dot product of every row of `encodings` with every row.
# NOTE(review): assuming `encodings` is (N, D), the result is (N, N), so memory
# grows quadratically with the number of encoded messages - confirm N is small
# enough before running this on the full file.
sim = encodings @ encodings.T

In [None]:
# Heatmap of the similarity matrix restricted to the first 100 messages.
px.imshow(sim[:100, :100], height=700)

In [None]:
# For each of the first ten messages, print the top-ranked message followed by
# the next nine entries of its similarity ranking (highest similarity first).
matching_indices = torch.argsort(sim, dim=-1, descending=True)
for ranking in matching_indices[:10]:
    top_hit, runners_up = ranking[0], ranking[1:10]
    print()
    print(repr(texts[top_hit]))
    for neighbour in runners_up:
        print("  ", repr(texts[neighbour]))

In [None]:
@torch.no_grad()
def find_messages(text: str, count: int = 20):
    """
    Print the stored commit messages that score highest against ``text``.

    ``text`` is encoded with CLIP, passed through ``normalize_embedding`` and
    compared by dot product with every row of the notebook-global
    ``encodings``. The query is printed first, then up to ``count`` lines of
    ``score  message`` in descending score order, looked up in the
    notebook-global ``texts``.

    :param text: query string
    :param count: maximum number of matches to print
    """
    query = ClipSingleton.encode_text(text, truncate=True)
    query = normalize_embedding(query).cpu().float()
    scores = query @ encodings.T
    ranking = torch.argsort(scores, dim=-1, descending=True)
    print(repr(text))
    for hit in ranking[0, :count]:
        print(f"   {scores[0, hit]:.3f}", repr(texts[hit]))
    
# Example query against the encoded commit messages.
find_messages("tiredness")

In [None]:
# Peek at the first ten messages that were encoded.
texts[:10]

In [None]:
# Show the hexadecimal code point of every character in a sample commit
# message that mixes an emoji with CJK text.
s = "⚡ 优化功能选项信息"
list(map(hex, map(ord, s)))

In [None]:
# Load commits and their embeddings through the project's load_data helper.
# NOTE(review): the positional arguments look like (data directory, start
# date, 100 = presumably a day count, embedding model name) - confirm against
# scripts.github_commits.load_data.
from scripts.github_commits import load_data
commits, embeddings = load_data(
    Path("~/prog/data/gharchive").expanduser(), datetime.date(2023, 11, 20), 100, "thenlper/gte-small",
)
embeddings.shape

In [None]:
# Project the first `limit` embeddings to 2-D with t-SNE and plot the layout.
limit = 5000

from sklearn.manifold import TSNE
# t-SNE is stochastic; a fixed random_state makes the layout reproducible
# across "Restart & Run All".
reducer = TSNE(2, verbose=1, perplexity=20, random_state=42)
positions = reducer.fit_transform(embeddings[:limit])

px.scatter(x=positions[:limit, 0], y=positions[:limit, 1], height=1000)

In [None]:
# Cluster the first `limit` embeddings into 50 groups.
from sklearn.cluster import KMeans
# KMeans initialisation is random; a fixed random_state keeps cluster ids (and
# therefore the plot colours below) stable between runs.
clusterer = KMeans(50, n_init="auto", random_state=42)
labels = clusterer.fit_predict(embeddings[:limit])

In [None]:
# Same 2-D layout as before, now coloured by KMeans label (as a string, so each
# label gets its own discrete colour) and with the repository name on hover.
px.scatter(
    x=positions[:limit, 0],
    y=positions[:limit, 1],
    color=list(map(str, labels)),
    hover_data=dict(repo=[commit["repo"] for commit in commits[:limit]]),
    height=1000,
)

In [None]:
# IPython-only help lookup (trailing `?`); leftover from interactive
# exploration, not valid plain Python.
px.scatter?