In [1]:
import os
from time import time
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
from huggingface_hub import login
from sklearn.model_selection import StratifiedKFold, KFold
# Register tqdm's pandas integration (adds progress_apply to DataFrame/Series).
# NOTE(review): time, np and login are not used by the cells in this notebook.
tqdm.pandas()

# Change the working directory to the directory containing the script
# (presumably so that the `utils` module imported below resolves from here).
os.chdir("/group-volume/binfeng/wsdm/stage_qft")
# NOTE(review): star import; this notebook relies on format_text and format_label from utils.
from utils import *


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Tokenizer source and token budgets used throughout the notebook.
MODEL_PATH = "google/gemma-2-9b-it"  # Hugging Face Hub id the tokenizer is loaded from
MAX_LENGTH = 2000  # passed to format_text as max_len and used by tokenizer_func as the pad/truncate length
MAX_PROMPT_LENGTH = 400  # passed to format_text as max_prompt_len (presumably a prompt token cap — confirm in utils)

## Tokenizer

In [3]:
# Load the Gemma-2 9B instruction-tuned tokenizer, configure padding, and save a local copy
# (saved after the padding overrides below, so the copy reflects them).
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Pad with the EOS token.
# NOTE(review): the pretrained tokenizer may already define its own pad token; confirm this
# override matches what the training/inference code expects.
tokenizer.pad_token = tokenizer.eos_token
# Padding goes after the text when tokenizer_func pads to max_length.
tokenizer.padding_side = 'right'
tokenizer.save_pretrained("/group-volume/binfeng/wsdm/tokenizer/gemma9b")

('/group-volume/binfeng/wsdm/tokenizer/gemma9b/tokenizer_config.json',
 '/group-volume/binfeng/wsdm/tokenizer/gemma9b/special_tokens_map.json',
 '/group-volume/binfeng/wsdm/tokenizer/gemma9b/tokenizer.model',
 '/group-volume/binfeng/wsdm/tokenizer/gemma9b/added_tokens.json',
 '/group-volume/binfeng/wsdm/tokenizer/gemma9b/tokenizer.json')

## Prepare Data

In [4]:
# Load the pairwise-preference set (prompt / response_a / response_b / winner, plus calibrated
# logits), drop incomplete rows, and build the model input text and label for each row.
# NOTE(review): dropna() considers every column, not only the ones used below; confirm that
# dropping rows with a missing value in any other column is intended.
ft = pd.read_parquet("/group-volume/binfeng/wsdm/stage_qft/data/ft48k_calibrated.parquet").dropna()
# Iterate the three text columns directly rather than DataFrame.apply(axis=1): no per-row
# Series is built, progress is shown for this slow tokenizer-backed step, and an empty
# frame yields an empty column instead of an unassignable DataFrame.
ft["text"] = [
    format_text(tokenizer, prompt, response_a, response_b,
                max_len=MAX_LENGTH, max_prompt_len=MAX_PROMPT_LENGTH)
    for prompt, response_a, response_b in tqdm(
        zip(ft["prompt"], ft["response_a"], ft["response_b"]),
        total=len(ft),
        desc="format ft",
    )
]
# The label depends only on the `winner` column.
ft["label"] = ft["winner"].map(format_label)


In [5]:
# Load the soft-label preference set, drop incomplete rows, and build the model input text
# and label for each row, exactly as for the ft set.
# NOTE(review): dropna() considers every column, not only the ones used below; confirm that
# dropping rows with a missing value in any other column is intended.
soft = pd.read_parquet("/group-volume/binfeng/wsdm/stage_qft/data/soft87k.parquet").dropna()
# Iterate the three text columns directly rather than DataFrame.apply(axis=1): no per-row
# Series is built, progress is shown for this slow tokenizer-backed step, and an empty
# frame yields an empty column instead of an unassignable DataFrame.
soft["text"] = [
    format_text(tokenizer, prompt, response_a, response_b,
                max_len=MAX_LENGTH, max_prompt_len=MAX_PROMPT_LENGTH)
    for prompt, response_a, response_b in tqdm(
        zip(soft["prompt"], soft["response_a"], soft["response_b"]),
        total=len(soft),
        desc="format soft",
    )
]
# The label depends only on the `winner` column.
soft["label"] = soft["winner"].map(format_label)


In [6]:
# Hold out one fold (1 / N_VAL_FOLDS of the rows) of each dataset for validation.
# Only the first fold of each splitter is used, so take it with next() instead of a
# loop that breaks after one iteration.
N_VAL_FOLDS = 40
SPLIT_SEED = 66

# ft: stratify on `language` so the held-out fold keeps the same language mix.
skf = StratifiedKFold(n_splits=N_VAL_FOLDS, shuffle=True, random_state=SPLIT_SEED)
train_index, val_index = next(skf.split(ft, ft["language"]))
ft_train, ft_val = ft.iloc[train_index], ft.iloc[val_index]
print(len(ft_train), len(ft_val))


# The soft file is not the "_calibrated" variant. Alias its raw logits under the "_cali"
# column names so both datasets expose the same columns to the Dataset-building cell.
# NOTE(review): these are uncalibrated logits copied verbatim; confirm this is intended.
soft["logits_qwencd_cali"] = soft["logits_qwencd"]
soft["logits_qwen32_cali"] = soft["logits_qwen32"]
kf = KFold(n_splits=N_VAL_FOLDS, shuffle=True, random_state=SPLIT_SEED)
train_index, val_index = next(kf.split(soft))
soft_train, soft_val = soft.iloc[train_index], soft.iloc[val_index]
print(len(soft_train), len(soft_val))

47226 1211
85563 2194




In [7]:
def tokenizer_func(example):
    """Tokenize the ``text`` field of a ``Dataset.map`` batch.

    Each sequence is truncated or padded to exactly ``MAX_LENGTH`` tokens, using the
    module-level ``tokenizer`` and its configured padding side. The encodings are
    returned as NumPy arrays.
    """
    encode_options = {
        "padding": "max_length",
        "max_length": MAX_LENGTH,
        "truncation": True,
        "return_tensors": "np",
    }
    texts = example["text"]
    return tokenizer(texts, **encode_options)


# Columns kept from every split: the formatted text (tokenized below), the label,
# and the two logit columns.
KEEP_COLUMNS = ["text", "label", "logits_qwencd_cali", "logits_qwen32_cali"]
split_frames = {
    'ft_train': ft_train,
    'ft_val': ft_val,
    'soft_train': soft_train,
    'soft_val': soft_val,
}
# preserve_index=False keeps the shuffled pandas index out of the Arrow table, so no
# "__index_level_0__" column is created and none has to be removed afterwards.
raw_dataset = DatasetDict({
    name: Dataset.from_pandas(frame[KEEP_COLUMNS], preserve_index=False)
    for name, frame in split_frames.items()
})

# Tokenize in batches. The raw text is dropped once it has been encoded, and the label
# column is renamed to "labels".
tokenized_dataset = (
    raw_dataset
    .map(tokenizer_func, batched=True, remove_columns=["text"])
    .rename_column("label", "labels")
)
tokenized_dataset


Map: 100%|██████████| 47226/47226 [00:24<00:00, 1913.19 examples/s]
Map: 100%|██████████| 1211/1211 [00:00<00:00, 1590.01 examples/s]
Map: 100%|██████████| 85563/85563 [00:43<00:00, 1952.27 examples/s]
Map: 100%|██████████| 2194/2194 [00:00<00:00, 2325.00 examples/s]


DatasetDict({
    ft_train: Dataset({
        features: ['labels', 'logits_qwencd_cali', 'logits_qwen32_cali', 'input_ids', 'attention_mask'],
        num_rows: 47226
    })
    ft_val: Dataset({
        features: ['labels', 'logits_qwencd_cali', 'logits_qwen32_cali', 'input_ids', 'attention_mask'],
        num_rows: 1211
    })
    soft_train: Dataset({
        features: ['labels', 'logits_qwencd_cali', 'logits_qwen32_cali', 'input_ids', 'attention_mask'],
        num_rows: 85563
    })
    soft_val: Dataset({
        features: ['labels', 'logits_qwencd_cali', 'logits_qwen32_cali', 'input_ids', 'attention_mask'],
        num_rows: 2194
    })
})

In [10]:
# Spot-check one tokenized validation example: decode it back to text (keeping special
# tokens) and print its label.
# NOTE(review): the decoded text shows "<bos>" at the start of the prompt and of each
# response, not only at position 0. Presumably format_text (utils) leaves special tokens in
# each segment. Check whether these are BOS ids or literal text, and whether this matches
# the format used at inference.
i = 3
sample = tokenized_dataset["soft_val"][i]
decoded_text = tokenizer.decode(sample["input_ids"], skip_special_tokens=False)
print(decoded_text)
print("**label:", sample["labels"])

<bos><|User Prompt|>
<bos>Please write a small echo TCP server in Python.

<|Response A|>
<bos>Below is a simple **Echo TCP Server** implemented in Python. This server listens for incoming TCP connections, receives data from clients, and sends the same data back to them, effectively "echoing" the input.

### Echo TCP Server in Python

```python
import socket
import threading

def handle_client(client_socket, client_address):
    print(f"[+] New connection from {client_address}")
    try:
        while True:
            # Receive data from the client (buffer size: 1024 bytes)
            data = client_socket.recv(1024)
            if not data:
                # No data received, client has closed the connection
                print(f"[-] Connection closed by {client_address}")
                break
            print(f"[{client_address}] Received: {data.decode().strip()}")
            
            # Echo the received data back to the client
            client_socket.sendall(data)
    ex

In [9]:
# Persist the four tokenized splits (ft_train, ft_val, soft_train, soft_val) to disk;
# presumably consumed by the stage_qft training code.
tokenized_dataset.save_to_disk("/group-volume/binfeng/wsdm/stage_qft/dataset/tokenized_gemma9b")

Saving the dataset (1/1 shards): 100%|██████████| 47226/47226 [00:00<00:00, 222681.83 examples/s]

Saving the dataset (1/1 shards): 100%|██████████| 47226/47226 [00:00<00:00, 149622.43 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1211/1211 [00:00<00:00, 92372.78 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 85563/85563 [00:00<00:00, 115858.71 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2194/2194 [00:00<00:00, 101858.49 examples/s]
