In [None]:
import sys
import os
import pandas as pd
from tqdm.notebook import tqdm
from itertools import combinations, cycle
from collections import Counter
import warnings
import matplotlib.pyplot as plt
import torch

warnings.filterwarnings('ignore')
sys.path.append(os.path.abspath('..'))

from utils import username_to_repositorys, repository_to_shas, sha_to_detail, tokenizer_with_padding, tokens_transfomer

### Configuration

In [None]:
# Length argument handed to `tokenizer_with_padding` for every sample.
INPUT_TOKEN_SIZE = 4096

# Token -> id mapping, seeded with the '[PAD]' and '[UNK]' entries; frequent
# tokens are appended to it in the vocabulary cell further down.
VOCAB = {
    '[PAD]': 0,
    '[UNK]': 1,
}

# Repositories to collect samples from (looks like "owner/name" slugs).
REPOSITORYS = [
    'naturesh/code2vec',
]

# Rotate only over tokens that are actually configured, so that switching
# tokens after a failure never lands on a missing (None / empty) value.
# When no token is configured at all, fall back to a single None so that
# `TOKEN` is still defined, exactly as before.
_configured_tokens = [
    token
    for token in (os.environ.get('GITHUB_TOKEN'), os.environ.get('GITHUB_TOKEN_2'))
    if token
]
token_cycle = cycle(_configured_tokens or [None])
TOKEN = next(token_cycle)

    

#### Load Repositories

In [None]:
# Rows collected from every commit of every repository; turned into a
# DataFrame with columns name / text / ext right after this loop.
X = []
# Number of repositories whose processing raised an exception.
failed = 0

repo_pgb = tqdm(REPOSITORYS, desc='Repositories')

for repo in repo_pgb:

    repo_pgb.set_postfix({'target': repo, 'failed': failed})

    try:
        shas = repository_to_shas(repo, token=TOKEN)

        # The context manager closes the inner bar even when a request raises,
        # so a failing repository does not leave a dangling progress bar.
        with tqdm(shas, desc='Shas', leave=False) as sha_pgb:
            for sha in sha_pgb:
                X.extend(sha_to_detail(repo, sha, token=TOKEN))

    except Exception as e:
        # Best effort: rows gathered before the failure are kept, the next
        # token is selected for the following repositories, and this
        # repository is not retried.
        TOKEN = next(token_cycle)
        print('-', e)
        failed += 1
        # Refresh the postfix so the failure count is visible immediately.
        repo_pgb.set_postfix({'target': repo, 'failed': failed})



# Build the sample table, keep only the first row for each distinct `text`,
# and renumber the rows from 0.
df = (
    pd.DataFrame(X, columns=['name', 'text', 'ext'])
    .drop_duplicates(subset=['text'], keep='first')
    .reset_index(drop=True)
)


#### Filter by extension

In [None]:
# Only rows whose `ext` is listed here are kept.
allowed_extensions = ['.java']

ext_mask = df['ext'].isin(allowed_extensions)
filtered_df = df.loc[ext_mask]

# Renumber the surviving rows from 0.
df = filtered_df.reset_index(drop=True)

#### Undersampling based on name

In [None]:

# Per-name cap: the mean number of rows per `name`, truncated to an int.
name_counts = df['name'].value_counts()
if name_counts.empty:
    # Without this guard the mean is NaN and int() fails with an opaque message.
    raise ValueError('no samples left to undersample; check the repository and extension filters')
mean = int(name_counts.mean())


def _apply(group):
    """Return at most `mean` rows of `group`, sampled with a fixed seed."""
    return group.sample(n=min(len(group), mean), random_state=42)


# Iterating over the groups (instead of `groupby(...).apply`) always hands the
# full rows, including the `name` column, to `_apply`. Newer pandas versions
# exclude the grouping column inside `apply`, and `reset_index(drop=True)`
# would then discard `name` entirely. Groups are visited in sorted key order,
# so the selected rows and their order match the `apply` based version.
df = pd.concat(
    [_apply(group) for _, group in df.groupby('name')]
).reset_index(drop=True)

In [None]:
# Display the last rows of `df` as a quick sanity check.
df.tail()

In [None]:
# Row count per `name`, ordered by name; drawn as one bar per name.
users = df['name'].value_counts().sort_index()
bar_positions = range(len(users))

plt.figure(figsize=(5, 2))
plt.bar(bar_positions, users.values, color='yellowgreen')
plt.xticks(bar_positions)
plt.xlabel('user distribution')
plt.ylabel('Number of Samples')
plt.show()

#### Create sentence-to-extension dict

In [None]:
# Lookup table from a sample's text to its file extension. Built straight
# from the two columns instead of a row-by-row `iterrows` loop; if a text
# occurs more than once, the last row wins, as with the loop.
EXT = dict(zip(df['text'], df['ext']))

#### Create Contrastive Pair

In [None]:
# Every unordered pair of samples that share an extension becomes one row:
# label 1 when both samples have the same `name`, -1 otherwise.
pair = []

for ext, group in df.groupby('ext'):

    # Pull the needed columns out once per group; looking rows up with
    # `df.loc` inside the quadratic pair loop is needlessly slow.
    records = list(zip(group['name'].tolist(), group['text'].tolist()))

    for (name1, text1), (name2, text2) in combinations(records, 2):

        label = 1 if name1 == name2 else -1

        pair.append([text1, text2, label, ext])


pair = pd.DataFrame(pair, columns=['p1', 'p2', 'label', 'ext'])
pair.tail()
        

#### Undersampling based on label

In [None]:

# Size of the rarest label; every label is cut down to this many pairs.
min_count = pair['label'].value_counts().min()

# Sample each label group explicitly rather than through `groupby(...).apply`:
# newer pandas versions exclude the grouping column inside `apply`, and
# `reset_index(drop=True)` would then drop `label` from the result. Groups are
# visited in sorted key order, so rows and order match the `apply` version.
pair = pd.concat(
    [group.sample(min_count, random_state=42) for _, group in pair.groupby('label')]
).reset_index(drop=True)

print('distribution', pair['label'].value_counts())

#### Create VOCAB

In [None]:
# A token must occur at least this many times to receive its own id.
MIN_TOKEN_FREQUENCY = 30

all_sentences = set(pair['p1'].tolist() + pair['p2'].tolist())

# Flatten the tokens of every distinct sentence into one list.
# - A nested comprehension replaces `sum(lists, [])`, which copies the growing
#   list on every step (quadratic).
# - The sentences are visited in sorted order because set iteration order
#   changes between kernel runs; that order decides how `most_common()` ranks
#   tokens with equal counts, and therefore which ids they get.
all_sentences_token = [
    token
    for sentence in sorted(all_sentences)
    for token in tokenizer_with_padding(sentence, EXT[sentence], INPUT_TOKEN_SIZE)
]

token_counter = Counter(all_sentences_token)

# Assign the next free id to every sufficiently frequent token that is not
# already in the vocabulary, most frequent first.
for token, count in token_counter.most_common():
    if count >= MIN_TOKEN_FREQUENCY and token not in VOCAB:
        VOCAB[token] = len(VOCAB)

# Count the observed tokens that did not make it into the vocabulary directly,
# instead of deriving the number from the two sizes (which is only right when
# '[PAD]' was observed and '[UNK]' was not).
unk_token_count = sum(1 for token in token_counter if token not in VOCAB)

print(f"전체 고유 토큰 수: {len(token_counter)}")
print(f"어휘에 포함된 토큰 수: {len(VOCAB)}")
print(f"UNK로 처리될 토큰 수: {unk_token_count}")


#### Create Contrastive Dataset

In [None]:
def PreProcessing(text, ext=''):
    """Tokenize `text` and convert the tokens with `VOCAB`.

    The extension is looked up in `EXT`; `ext` is only used as a fallback
    when `text` is not a key of `EXT`.

    Returns whatever `tokens_transfomer` returns for the padded token list.
    """
    tokens = tokenizer_with_padding(text, EXT.get(text, ext), INPUT_TOKEN_SIZE)
    return tokens_transfomer(tokens, VOCAB)


# A single text takes part in many pairs, so encode each distinct text once
# and reuse the result instead of re-tokenizing it for every pair.
# NOTE(review): assumes the tokenizer/transformer helpers are deterministic.
_encoded_by_text = {
    text: PreProcessing(text)
    for text in set(pair['p1']) | set(pair['p2'])
}

p1_encoded = [_encoded_by_text[text] for text in pair['p1']]
p2_encoded = [_encoded_by_text[text] for text in pair['p2']]


In [None]:
# The nested list puts the pair side (p1 / p2) on the first axis; swapping the
# first two axes puts the pair index first instead.
paired_ids = torch.LongTensor([p1_encoded, p2_encoded])
x_train = paired_ids.permute(1, 0, 2)

# One label per pair, in the same row order as `pair`.
pair_labels = pair['label'].tolist()
y_train = torch.LongTensor(pair_labels)

In [None]:
# Display the shapes of the input and label tensors.
x_train.shape, y_train.shape

In [None]:
# Number of pairs per extension, drawn as one bar per extension.
exts = pair['ext'].value_counts()
tick_positions = range(len(exts))

plt.figure(figsize=(5, 2))
plt.bar(exts.keys(), exts.values, color='yellowgreen')
plt.xticks(tick_positions)
plt.xlabel('ext distribution')
plt.ylabel('Number of Samples')
plt.show()

In [None]:
# Make sure the target directory exists; `torch.save` does not create it.
os.makedirs('../colab/train', exist_ok=True)

# Persist the training tensors together with the vocabulary they were encoded with.
torch.save(x_train, '../colab/train/x_train.pt')
torch.save(y_train, '../colab/train/y_train.pt')
torch.save(VOCAB,   '../colab/train/vocab.pt')

---

<br><br><br><br><br><br>