In [1]:
from openai import OpenAI
import torch
import torch.nn.functional as F
from torch.nn.functional import cosine_similarity
from collections import Counter
import numpy as np
import pandas as pd
import json
import random
import itertools
import csv
import os
from tqdm.notebook import tqdm,trange
import tiktoken
import matplotlib.pyplot as plt
from concurrent.futures import ThreadPoolExecutor, as_completed
from scipy.linalg import sqrtm
from numpy import iscomplexobj, trace, cov
from utils.compute_mauve import *

In [3]:
def random_sample_length():
    """Draw one target length at random from the JSON list stored at ``length_path``.

    The file is re-read on every call. The value is picked with
    ``np.random.choice`` (no weights given) and returned as-is.
    """
    with open(length_path, 'r') as handle:
        candidate_lengths = json.load(handle)
    # The with-block has already closed the file; no explicit close() is needed.
    return np.random.choice(candidate_lengths)

def random_sample_subcategory(c):
    """Pick one subcategory for category ``c`` from the JSON file at ``subcategory_path``.

    The entry ``data[c]`` must provide the parallel lists ``'Subcategories'`` and
    ``'Probabilities'``; the latter is used as sampling weights.
    """
    with open(subcategory_path, 'r', encoding='utf-8') as handle:
        catalog = json.load(handle)
    options = catalog[c]['Subcategories']
    option_weights = catalog[c]['Probabilities']
    picked = random.choices(options, weights=option_weights, k=1)
    return picked[0]

def splitstr(text,strr):
    """Split ``text`` on every occurrence of the separator ``strr`` and return the list of pieces."""
    pieces = text.split(strr)
    return pieces

def save_listdata_to_json(data, path):
    """Write the list ``data`` to the JSON file at ``path``, appending if it already exists.

    Args:
        data: list of JSON-serialisable items.
        path: destination file. Missing parent directories are created.

    If ``path`` does not exist, ``data`` is dumped as a new JSON list. Otherwise the
    existing file must contain a JSON list; ``data`` is appended to it and the file
    is rewritten in place.
    """
    filedir = os.path.dirname(path)
    # dirname is '' for a bare filename and os.makedirs('') raises, so only
    # create the directory when there actually is one.
    if filedir:
        os.makedirs(filedir, exist_ok=True)

    if not os.path.exists(path):
        with open(path, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        return

    with open(path, 'r+', encoding='utf-8') as file:
        file_data = json.load(file)
        file_data.extend(data)
        file.seek(0)
        json.dump(file_data, file, ensure_ascii=False, indent=4)
        # Drop any leftover bytes of the previous content after the rewrite.
        file.truncate()

def get_length(text):
    """Return the number of tokens in ``text`` under the tiktoken encoding for "gpt-3.5-turbo"."""
    tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
    token_ids = tokenizer.encode(text)
    return len(token_ids)

def json2csv(json_file,label_file,csv_file):
    """Join a JSON list of texts with a JSON list of combined labels into a CSV file.

    Args:
        json_file: path to a JSON list of items (each is written as ``str(item)``).
        label_file: path to a JSON list of labels of the form ``"<label1>_<label2>..."``.
        csv_file: output path; receives a header row ``text,label1,label2`` followed
            by one fully quoted row per (item, label) pair.

    Only the first two ``_``-separated parts of each label are written; pairing
    stops at the shorter of the two lists.
    """
    with open(json_file, 'r', encoding='utf-8') as src:
        texts = json.load(src)
    with open(label_file, 'r', encoding='utf-8') as src:
        combined_labels = json.load(src)

    with open(csv_file, 'w', newline='', encoding='utf-8') as dst:
        out = csv.writer(dst, quoting=csv.QUOTE_ALL, escapechar='\\')
        out.writerow(['text','label1','label2'])
        for text, combined in zip(texts, combined_labels):
            row_text = str(text)
            parts = combined.split('_')
            out.writerow([row_text, parts[0], parts[1]])


def get_path(index, category, extension):
    """Build ``<result_path>epoch_<index>_<category>.<extension>`` from the global ``result_path``.

    ``result_path`` is concatenated verbatim, so it is expected to end with a path separator.
    """
    filename = 'epoch_' + str(index) + '_' + category + '.' + extension
    return result_path + filename

def load_json_data(file_path):
    """Read the UTF-8 encoded JSON file at ``file_path`` and return the decoded object."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        # The with-block closes the file on exit; no explicit close() is needed.
        return json.load(handle)

def transfer_blank(text,p):
    """Randomly blank out tokens of ``text``.

    ``text`` is tokenised with the tiktoken encoding for "gpt-3.5-turbo"; each token
    is independently replaced, with probability ``p``, by the first token id of the
    encoding of ``'_'``. The resulting token sequence is decoded back to a string.

    Args:
        text: input string.
        p: per-token replacement probability (compared against ``random.random()``).

    Returns:
        The decoded string with some tokens replaced.
    """
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    # The blank token does not depend on the loop, so encode it once instead of
    # once per replaced token.
    blank_token = encoding.encode('_')[0]
    masked_tokens = [blank_token if random.random() < p else token for token in encoding.encode(text)]
    return encoding.decode(masked_tokens)

def aggregate_data(index,lst,num):
    """Merge the per-category JSON lists of epoch ``index`` into the next epoch's 'all' file.

    Args:
        index: epoch whose ``get_path(index, <category>, 'json')`` files are read.
        lst: category names, loaded and concatenated in this order.
        num: chunk size used to cut the concatenated list before interleaving.

    The concatenated list is cut into consecutive chunks of ``num`` items and the
    chunks are interleaved position by position (item 0 of every chunk, then item 1
    of every chunk, ...). The result is passed to ``save_listdata_to_json`` for
    ``get_path(index + 1, 'all', 'json')``.
    """
    pooled = []
    for category in lst:
        pooled += load_json_data(get_path(index, category, 'json'))
    blocks = [pooled[start:start + num] for start in range(0, len(pooled), num)]
    # zip stops at the shortest block, so a short final block truncates every position beyond its length.
    interleaved = [entry for column in zip(*blocks) for entry in column]
    save_listdata_to_json(interleaved, get_path(index+1,'all','json'))

def cross_choice(data):
    """Return two distinct elements of ``data`` as a list, in a randomly chosen order.

    The pair is drawn without replacement via ``np.random.choice``; ``random.choice``
    then decides between the drawn order and its reverse.
    """
    pair = np.random.choice(data, 2, replace=False).tolist()
    orderings = [pair, pair[::-1]]
    return random.choice(orderings)

def mutate_choice(data):
    """Return one element of the sequence ``data`` picked by ``random.choice``."""
    picked = random.choice(data)
    return picked

def generate_choice(data):
    """Return up to three distinct elements of ``data`` (all of them if fewer than three) as a list.

    Elements are drawn without replacement via ``np.random.choice``.
    """
    sample_size = min(3, len(data))
    drawn = np.random.choice(data, sample_size, replace=False)
    return drawn.tolist()

def count_selected_indices(part_sizes, selected_indices):
    """Count how many of ``selected_indices`` fall into each consecutive partition.

    Args:
        part_sizes: sizes of consecutive index ranges; part ``i`` covers the
            half-open range starting where part ``i-1`` ended (the first starts at 0).
        selected_indices: indices to bucket.

    Returns:
        A list with one count per entry of ``part_sizes``.
    """
    counts = []
    lower = 0
    for size in part_sizes:
        upper = lower + size
        hits = [idx for idx in selected_indices if lower <= idx < upper]
        counts.append(len(hits))
        lower = upper
    return counts

def calculate_fid(embeddings1, embeddings2):
    """Compute, print and return the Fréchet distance between two embedding sets.

    Each input is a 2-D array with one sample per row. The distance is
    ``||mu1 - mu2||^2 + trace(S1 + S2 - 2 * sqrtm(S1 @ S2))`` where ``mu`` and ``S``
    are the column means and the covariance (``rowvar=False``) of each set. If the
    matrix square root comes back complex, only its real part is used.
    """
    mean_a = embeddings1.mean(axis=0)
    cov_a = cov(embeddings1, rowvar=False)
    mean_b = embeddings2.mean(axis=0)
    cov_b = cov(embeddings2, rowvar=False)

    mean_term = np.sum((mean_a - mean_b) ** 2.0)
    cov_sqrt = sqrtm(cov_a.dot(cov_b))
    # Numerical error can leave a small imaginary component; discard it.
    if iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real

    fid = mean_term + trace(cov_a + cov_b - 2.0 * cov_sqrt)
    print("FID: ", fid)
    return fid

def calculate_all_metrics(synthetic_embeddings, original_embeddings):
    """Print the MAUVE score between synthetic and original embeddings.

    The synthetic embeddings are passed to ``compute_mauve`` as the first (p)
    feature set and the original embeddings as the second (q). Nothing is returned.
    """
    mauve_result = compute_mauve(synthetic_embeddings, original_embeddings)
    print("MAUVE: ", mauve_result.mauve)


def self_similarity(embeddings):
    """Print the mean cosine similarity over all ordered pairs of distinct rows.

    Args:
        embeddings: 2-D tensor with one embedding per row.

    The full pairwise cosine-similarity matrix is built, the diagonal
    (self-similarity) is masked out, and the mean of the remaining entries is
    printed. Nothing is returned.
    """
    similarity_matrix = cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
    # 1 everywhere except the diagonal, so each row's similarity to itself is excluded.
    off_diagonal = torch.ones_like(similarity_matrix) - torch.eye(similarity_matrix.size(0), device=similarity_matrix.device)
    average_similarity = (similarity_matrix * off_diagonal).sum() / off_diagonal.sum()
    print(f"Average similarity of the dataset: {average_similarity.item()}")


In [4]:
def get_initial_population(syn_list, max_workers, combined_label, client):
    """Generate the epoch-1 synthetic population with the chat model and save it.

    Args:
        syn_list: ``syn_list[j]`` is the number of samples to request for label ``j``.
        max_workers: size of the thread pool issuing the requests.
        combined_label: mapping from position ``j`` to a label string ``"<label1>_<label2>"``.
        client: object exposing ``chat.completions.create``.

    The per-sample label list is written to ``get_path(1, 'all_label', 'json')`` only
    if that file does not exist yet. The generated texts, in the same order as the
    labels, are passed to ``save_listdata_to_json`` for ``get_path(1, 'all', 'json')``.
    """
    def get_initdata(label):
        """Request one text for ``label``.

        Returns the first reply whose token length is within +/-20% of a sampled
        target length. After three out-of-range replies the closest one is
        returned. Exceptions are printed and retried without counting as a miss.
        """
        label_parts = label.split("_")
        # label2 is not referenced below; presumably the (redacted) prompt used it - TODO confirm.
        label1, label2 = label_parts[0], label_parts[1]
        target_length = random_sample_length()
        subcategory = random_sample_subcategory(label1)
        best_text = ""
        best_gap = float('inf')
        misses = 0
        while misses < 3:
            try:
                completion = client.chat.completions.create(
                    model="gpt-3.5-turbo",
                    temperature=1.2,
                    messages=[
                        {"role": "system", "content": "xxxxxxxxxxxxxxxxxx"},
                        {"role": "user", "content": "xxxxxxxxxxxxxxxxxxxxx"},
                    ]
                )
                candidate = completion.choices[0].message.content
                candidate_length = get_length(candidate)
                gap = abs(candidate_length - target_length)

                if target_length * 0.8 <= candidate_length <= target_length * 1.2:
                    return candidate

                # Out of range: remember the closest reply so far and count the miss.
                if gap < best_gap:
                    best_text, best_gap = candidate, gap
                misses += 1
            except Exception as err:
                print(f"Error: {err}")
        return best_text

    syn_num = sum(syn_list)
    # One label entry per requested sample, grouped by label position.
    label_list = [combined_label[j] for j, count in enumerate(syn_list) for _ in range(count)]
    if not os.path.exists(get_path(1,'all_label','json')):
        save_listdata_to_json(label_list, get_path(1,'all_label','json'))

    results = [None] * len(label_list)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        position_of = {}
        for position, label in enumerate(label_list):
            position_of[executor.submit(get_initdata, label)] = position

        with tqdm(total=syn_num, desc="get_initial_population") as pbar:
            for finished in as_completed(position_of):
                # Store by submission position so texts stay aligned with label_list.
                text = finished.result()
                results[position_of[finished]] = text
                pbar.update(1)

    save_listdata_to_json(results, get_path(1, 'all', 'json'))

In [7]:
def filter_elite_choice_per_class(train_embeddings, embeddings_pt, class_indices, num_choice, sigma=None, threshold=0.80):
    """Select up to ``num_choice`` candidates of one class by nearest-neighbour voting.

    Every row of ``train_embeddings`` votes for its closest candidate (Euclidean
    distance). Candidates are then visited in descending vote order and accepted
    unless their cosine similarity to an already accepted candidate is at least the
    current threshold. If a full pass does not yield enough candidates, the
    threshold is raised by 0.01 and the pass is repeated.

    Args:
        train_embeddings: 2-D tensor of reference embeddings for this class.
        embeddings_pt: 2-D tensor of all candidate embeddings.
        class_indices: row indices into ``embeddings_pt`` belonging to this class.
        num_choice: number of candidates to select; capped at ``len(class_indices)``.
        sigma: if not None and > 0, Gaussian noise with this standard deviation is
            added to the vote counts, which are then clamped at 0.
        threshold: initial cosine-similarity cut-off for rejecting near-duplicates.

    Returns:
        A list of selected indices expressed in ``embeddings_pt`` row coordinates.
    """
    class_indices = torch.as_tensor(class_indices, dtype=torch.long)
    num_candidates = class_indices.numel()
    # Never ask for more than exist; otherwise the selection loop could not terminate.
    num_choice = min(num_choice, num_candidates)
    if num_choice <= 0:
        return []

    # Use the GPU when present, but stay runnable on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_embeddings = train_embeddings.to(device)
    embeddings_pt = embeddings_pt[class_indices].to(device)

    distances = torch.cdist(train_embeddings, embeddings_pt)
    closest_indices = torch.argmin(distances, dim=1)
    votes = torch.bincount(closest_indices, minlength=num_candidates)

    if sigma is not None and sigma > 0:
        print("Adding noise with sigma =", sigma)
        noise = torch.normal(0.0, float(sigma), size=votes.shape, device=votes.device)
        # bincount yields int64; adding float noise in place would raise, so
        # build a new floating-point tensor instead.
        votes = torch.clamp(votes.to(noise.dtype) + noise, min=0)

    ranked = torch.topk(votes, num_candidates).indices.tolist()

    selected_indices = []
    already_selected = set()
    current_threshold = threshold

    while len(selected_indices) < num_choice:
        for idx in ranked:
            if len(selected_indices) >= num_choice:
                break
            if idx in already_selected:
                continue
            if selected_indices:
                selected_embeddings = embeddings_pt[selected_indices]
                similarity = F.cosine_similarity(embeddings_pt[idx].unsqueeze(0), selected_embeddings)
                # Too close to something already chosen: skip it on this pass.
                if torch.any(similarity >= current_threshold):
                    continue

            selected_indices.append(idx)
            already_selected.add(idx)

        if len(selected_indices) < num_choice:
            # Relax the diversity constraint a little and try again.
            current_threshold = round(current_threshold + 0.01, 2)

    # Map positions within the class back to rows of the full candidate tensor.
    return class_indices[selected_indices].tolist()

In [None]:
import torch

def select_tensors_by_probability(train_embeddings, probability):
    """Keep each row of ``train_embeddings`` independently with probability ``probability``.

    Args:
        train_embeddings: tensor whose first dimension indexes samples.
        probability: keep-probability in [0, 1]; ints such as ``1`` are accepted.

    Returns:
        A tensor containing the kept rows, in their original order.
    """
    # float() keeps torch.full from producing an integer tensor, which bernoulli rejects.
    keep_probs = torch.full((train_embeddings.size(0),), float(probability))
    bernoulli_samples = torch.bernoulli(keep_probs)
    # Turn the 0/1 draws into a boolean row mask on the same device as the data.
    selected_indices = bernoulli_samples.bool().to(train_embeddings.device)
    selected_tensors = train_embeddings[selected_indices]
    return selected_tensors


In [None]:
def mutate1(text):
    """Ask the chat model for a mutated variant of ``text`` and return it.

    A blanked copy of ``text`` is produced with ``transfer_blank(text, p)`` using the
    global ``p``. The request is repeated until a reply (with newlines replaced by
    spaces) has a token length within 80% relative deviation of the input's length.
    Exceptions are printed and the request is retried.
    """
    length = get_length(text)
    # blank is not referenced below; presumably the (redacted) prompt used it - TODO confirm.
    blank = transfer_blank(text,p)
    while True:
        try:
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                temperature=1.2,
                messages=[
                    {"role": "system", "content": "xxx"},
                    {"role": "user", "content": "xxx"},
                ]
            )
            candidate = response.choices[0].message.content.replace("\n", " ")
            relative_gap = abs(length-get_length(candidate)) / length
            if relative_gap <= 0.8:
                return candidate
        except Exception as err:
            print(f"Error: {err}")

In [None]:
def mutate2(text, p=0.5):
    """Ask the chat model for a second kind of variant of ``text`` and return it.

    The request is repeated until a reply (with newlines replaced by spaces) has a
    token length within 80% relative deviation of the input's length. Exceptions are
    printed and the request is retried. ``p`` is accepted but not used by the
    visible body.
    """
    length = get_length(text)

    while True:
        try:
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                temperature=1.2,
                messages=[
                    {"role": "system", "content": "xxx"},
                    {"role": "user", "content": "xxx"},
                ]
            )
            candidate = response.choices[0].message.content.replace("\n", " ")
            relative_gap = abs(length-get_length(candidate)) / length
            if relative_gap <= 0.8:
                return candidate
        except Exception as err:
            print(f"Error: {err}")

In [None]:
def cross(text):
    """Ask the chat model for a crossover result based on ``text`` and return it.

    The request is repeated until a reply (with newlines replaced by spaces) has a
    token length within 80% relative deviation of the input's length. Exceptions,
    including any raised while measuring ``text``, are printed and the attempt is
    retried.
    """
    while True:
        try:
            # Measured inside the try on every attempt, so a failure here is retried as well.
            length = get_length(text)
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                temperature=1.2,
                messages=[
                    {"role": "system", "content": "xxx"},
                    {"role": "user", "content": "xxx"},
                ]
            )
            candidate = response.choices[0].message.content.replace("\n", " ")
            relative_gap = abs(length-get_length(candidate)) / length
            if relative_gap <= 0.8:
                return candidate
        except Exception as err:
            print(f"Error: {err}")

In [None]:
def generate(label):
    """Ask the chat model for a fresh sample and return it.

    A target length is re-sampled with ``random_sample_length()`` on every attempt.
    The request is repeated until a reply (with newlines replaced by spaces) has
    between 80 and 1000 tokens inclusive. Exceptions are printed and the attempt is
    retried.
    """
    while True:
        try:
            # length is not referenced below; presumably the (redacted) prompt used it - TODO confirm.
            length = random_sample_length()
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                temperature=1.2,
                messages=[
                    {"role": "system", "content": "xxx"},
                    {"role": "user", "content": "xxx"},
                ]
            )
            candidate = response.choices[0].message.content.replace("\n", " ")
            token_count = get_length(candidate)
            if 80 <= token_count <= 1000:
                return candidate
        except Exception as err:
            print(f"Error: {err}")

In [None]:
# --- Run configuration -------------------------------------------------------
elite_num = 2000      # total number of elite samples requested per epoch
syn_num = 10000
max_workers = 1000    # thread-pool size for the API calls
T = 10                # number of epochs
sigma = 0 #eps=inf
sample_p = 1          # per-row keep-probability for the training set (1 = keep all)
filter_p = 1          # passed as the similarity threshold of filter_elite_choice_per_class
result_path = '../data/openreview/dpga_result/xxx/'
os.makedirs(result_path, exist_ok=True)

# --- Training data -----------------------------------------------------------
train_embeddings = torch.load(train_emb_path)
df = pd.read_csv(train_csv_path, sep=',')
if sample_p != 1:
    # Subsample embeddings and rows with the SAME Bernoulli mask so they stay aligned.
    # Assumes row r of the CSV corresponds to row r of the embedding tensor - TODO confirm.
    keep_mask = torch.bernoulli(torch.full((train_embeddings.size(0),), float(sample_p))).bool()
    train_embeddings = train_embeddings[keep_mask.to(train_embeddings.device)]
    df = df[keep_mask.cpu().numpy()].reset_index(drop=True)

# --- Labels and per-class budgets --------------------------------------------
df['combined_label'] = df['label1'].astype(str) + "_" + df['label2'].astype(str)
# Unique labels in order of first appearance.
unique_labels = list(dict.fromkeys(df['combined_label'].values))
combined_label = {index: label for index, label in enumerate(unique_labels)}
combined_label_inv = {label: index for index, label in enumerate(unique_labels)}

with open(class_path, 'r') as f:
    class_list = json.load(f)
elite_list = allocate_elite_numbers(elite_num, class_list)
syn_list = [x*5 for x in elite_list]

if len(unique_labels) != len(class_list):
    # Raise instead of exit(): exit() in a notebook targets the kernel rather than stopping the cell.
    raise ValueError("The number of unique labels does not match the number of elite numbers!")

def _run_in_parallel(worker, inputs, desc):
    """Apply ``worker`` to every item of ``inputs`` on a thread pool; results keep input order."""
    outputs = [None] * len(inputs)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        slot_of = {pool.submit(worker, item): slot for slot, item in enumerate(inputs)}
        with tqdm(total=len(inputs), desc=desc) as pbar:
            for finished in as_completed(slot_of):
                outputs[slot_of[finished]] = finished.result()
                pbar.update(1)
    return outputs


# The reference embeddings used for the metrics do not change between epochs; load them once.
train_emb_all_np = torch.load(train_emb_path).cpu().numpy()

for i in range(1,T+1):
    # NOTE(review): endpoint and key are hard-coded placeholders; prefer reading them from the environment.
    client = OpenAI(base_url="xxxx",api_key="sk-xxxxxx")
    if i == 1:
        if not os.path.exists(get_path(i, 'all', 'json')):
            get_initial_population(syn_list, max_workers, combined_label, client)
    data_all = load_json_data(get_path(i,'all','json'))

    # (Re)compute the embeddings when they are missing or out of sync with the texts.
    all_emb_path = get_path(i, 'all_emb', 'pt')
    embeddings_pt = torch.load(all_emb_path) if os.path.exists(all_emb_path) else None
    if embeddings_pt is None or len(embeddings_pt) != len(data_all):
        get_embeddings(i,'all')
        embeddings_pt = torch.load(all_emb_path)

    labels_list = load_json_data(get_path(1,'all_label','json'))

    # Per-class elite selection.
    train_labels = df['combined_label'].values
    selected_indices = []
    for class_label, k in zip(unique_labels, elite_list):
        train_class_indices = [row for row, label in enumerate(train_labels) if label == class_label]
        class_indices = [row for row, label in enumerate(labels_list) if label == class_label]
        selected_indices.extend(filter_elite_choice_per_class(train_embeddings[train_class_indices], embeddings_pt, class_indices, k, sigma, filter_p))

    if not os.path.exists(get_path(i,'elite','json')):
        # Save the elite set.
        save_listdata_to_json([data_all[indice] for indice in selected_indices],get_path(i,'elite','json'))
        save_listdata_to_json([labels_list[indice] for indice in selected_indices],get_path(i,'elite_label','json'))
        json2csv(get_path(i,'elite','json'),get_path(i,'elite_label','json'),get_path(i,'elite','csv'))

    if not os.path.exists(get_path(i, 'elite_emb', 'pt')):
        elite_embeddings_pt = embeddings_pt[selected_indices]
        torch.save(elite_embeddings_pt,get_path(i,'elite_emb','pt'))

    elite_emb_np = torch.load(get_path(i,'elite_emb','pt')).cpu().numpy()
    calculate_all_metrics(elite_emb_np, train_emb_all_np)
    calculate_fid(elite_emb_np, train_emb_all_np)

    # NOTE(review): this stops after the first epoch, so the variation steps below
    # never run and T has no effect. Remove the break to run all T epochs.
    if i == 1:
        break

    # Variation operators. Each is called with the single argument its definition
    # accepts: the elite text for mutate1 / mutate2 / cross, the elite label for generate.
    elite_texts = [data_all[indice] for indice in selected_indices]
    elite_labels = [labels_list[indice] for indice in selected_indices]

    save_listdata_to_json(_run_in_parallel(mutate1, elite_texts, "mutate1"), get_path(i, 'mutate1', 'json'))
    save_listdata_to_json(_run_in_parallel(mutate2, elite_texts, "mutate2"), get_path(i, 'mutate2', 'json'))
    save_listdata_to_json(_run_in_parallel(cross, elite_texts, "cross"), get_path(i, 'cross', 'json'))
    save_listdata_to_json(_run_in_parallel(generate, elite_labels, "generate"), get_path(i, 'generate', 'json'))

    # Chunk size = number of elites, so the next population interleaves each elite
    # with its four offspring (elite, mutate1, mutate2, cross, generate).
    aggregate_data(i,['elite','mutate1','mutate2','cross','generate'], len(selected_indices))
