In [1]:
from openai import OpenAI
from sklearn.cluster import KMeans
import torch
import torch.nn.functional as F
from collections import Counter
import numpy as np
import json
import random
import math
import csv
import os
from tqdm.notebook import tqdm
import tiktoken
import matplotlib.pyplot as plt
from concurrent.futures import ThreadPoolExecutor, as_completed
from scipy.linalg import sqrtm
from numpy import iscomplexobj, trace, cov
from utils.compute_mauve import *

In [4]:
def random_sample_length():
    """Draw one entry uniformly at random from the JSON list stored at ``length_path``.

    ``length_path`` is a notebook-level global; the file is re-read on every call.

    Returns:
        The element picked by ``np.random.choice`` from the loaded list.
    """
    with open(length_path, 'r') as handle:
        candidates = json.load(handle)
    # The ``with`` block has already closed the file at this point.
    return np.random.choice(candidates)

def random_sample_author():
    """Draw one line uniformly at random from the text file at ``author_path``.

    ``author_path`` is a notebook-level global. Every line of the file is
    stripped of surrounding whitespace before sampling.

    Returns:
        The element picked by ``np.random.choice`` from the stripped lines.
    """
    with open(author_path, 'r', encoding='utf-8') as handle:
        authors = [raw_line.strip() for raw_line in handle.readlines()]
    return np.random.choice(authors)

def save_listdata_to_json(data, path):
    """Write ``data`` to a JSON file, appending to the stored list if the file exists.

    Args:
        data: List of JSON-serialisable items.
        path: Destination file. Missing parent directories are created.

    Behaviour:
        * If ``path`` does not exist, ``data`` is written as a new JSON list.
        * Otherwise the existing list is loaded, extended with ``data`` and
          written back in place.
    """
    filedir = os.path.dirname(path)
    # A bare file name has an empty dirname; os.makedirs('') would raise.
    if filedir:
        os.makedirs(filedir, exist_ok=True)

    if not os.path.exists(path):
        with open(path, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        return

    with open(path, 'r+', encoding='utf-8') as file:
        file_data = json.load(file)
        file_data.extend(data)
        file.seek(0)
        json.dump(file_data, file, ensure_ascii=False, indent=4)
        # Drop any leftover bytes from the previous content so the file
        # always holds exactly one valid JSON document.
        file.truncate()

def get_length(text):
    """Return how many tokens ``text`` encodes to under the gpt-3.5-turbo tokenizer."""
    tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
    token_ids = tokenizer.encode(text)
    return len(token_ids)

def json2csv(json_file,csv_file):
    """Convert a JSON list of strings into a one-column CSV with header ``text``.

    Carriage returns and newlines in each item become single spaces and
    double quotes are removed, so every item occupies one unquoted CSV row.

    Args:
        json_file: Path to a JSON file containing a list of strings.
        csv_file: Path of the CSV file to (over)write.
    """
    # \r and \n -> space, double quote -> dropped.
    cleanup_table = str.maketrans({'\r': ' ', '\n': ' ', '"': None})

    with open(json_file, 'r', encoding='utf-8') as src:
        items = json.load(src)

    with open(csv_file, 'w', newline='', encoding='utf-8') as dst:
        out = csv.writer(dst)
        out.writerow(['text'])
        out.writerows([entry.translate(cleanup_table)] for entry in items)

def get_path(index, category, extension):
    """Build the artefact path ``<result_path>epoch_<index>_<category>.<extension>``.

    ``result_path`` is a notebook-level global and is used as a plain string
    prefix (no path separator is inserted).
    """
    pieces = [result_path, 'epoch_', str(index), '_', category, '.', extension]
    return ''.join(pieces)

def load_json_data(file_path):
    """Parse the UTF-8 JSON file at ``file_path`` and return the decoded object."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)

def transfer_blank(text,p):
    """Randomly replace tokens of ``text`` with the token for ``'_'``.

    The text is tokenised with the gpt-3.5-turbo encoding; each token is
    independently replaced with probability ``p`` (one ``random.random()``
    draw per token) and the result is decoded back to a string.

    Args:
        text: Input string.
        p: Per-token replacement probability.

    Returns:
        The decoded string after replacement.
    """
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    text_token = encoding.encode(text)
    # Loop-invariant: look the placeholder token up once, not per replacement.
    blank_token = encoding.encode('_')[0]
    new_text_token = [blank_token if random.random() < p else x for x in text_token]
    return encoding.decode(new_text_token)

def aggregate_data(index,lst):
    """Merge several category files of epoch ``index`` into the next epoch's pool.

    For each category name in ``lst`` the JSON list at
    ``get_path(index, category, 'json')`` is loaded; the lists are concatenated
    in the given order and handed to ``save_listdata_to_json`` for
    ``get_path(index+1, 'all', 'json')``.
    """
    merged = [
        record
        for category in lst
        for record in load_json_data(get_path(index, category, 'json'))
    ]
    save_listdata_to_json(merged, get_path(index+1,'all','json'))

def cross_choice(data):
    """Pick two distinct entries of ``data`` and return them in a random order.

    Two items are sampled without replacement via ``np.random.choice``; the
    pair or its reverse is then chosen with ``random.choice``.

    Returns:
        A two-element list.
    """
    pair = np.random.choice(data, 2, replace=False).tolist()
    reversed_pair = pair[::-1]
    return random.choice([pair, reversed_pair])

def mutate_choice(data):
    """Return a single element of ``data`` chosen uniformly via ``random.choice``."""
    picked = random.choice(data)
    return picked

def generate_choice(data):
    """Sample up to three distinct entries of ``data`` without replacement.

    Returns:
        A list with ``min(3, len(data))`` elements.
    """
    sample_size = min(3, len(data))
    sampled = np.random.choice(data, sample_size, replace=False)
    return sampled.tolist()

def count_selected_indices(part_sizes, selected_indices):
    """Count how many selected indices fall into each consecutive partition.

    The index range is split into back-to-back half-open intervals whose
    lengths are given by ``part_sizes`` (the first starts at 0).

    Args:
        part_sizes: Sizes of the consecutive partitions.
        selected_indices: Indices to bucket.

    Returns:
        A list with one count per partition, in the same order as ``part_sizes``.
    """
    counts = []
    lower = 0
    for width in part_sizes:
        upper = lower + width
        hits = [idx for idx in selected_indices if lower <= idx < upper]
        counts.append(len(hits))
        lower = upper
    return counts

def calculate_fid(embeddings1, embeddings2):
    """Compute and print the Fréchet distance between two embedding sets.

    Each set is summarised by its column-wise mean and covariance (rows are
    samples). The distance is the squared difference of the means plus
    ``trace(S1 + S2 - 2 * sqrt(S1 @ S2))``.

    Args:
        embeddings1: 2-D array, one sample per row.
        embeddings2: 2-D array, one sample per row.

    Returns:
        The distance value (also printed with the prefix ``"FID: "``).
    """
    mean_a = embeddings1.mean(axis=0)
    cov_a = cov(embeddings1, rowvar=False)
    mean_b = embeddings2.mean(axis=0)
    cov_b = cov(embeddings2, rowvar=False)

    mean_term = np.sum((mean_a - mean_b) ** 2.0)

    cross_root = sqrtm(cov_a.dot(cov_b))
    # sqrtm may return a complex result because of numerical error; keep the real part.
    if iscomplexobj(cross_root):
        cross_root = cross_root.real

    fid = mean_term + trace(cov_a + cov_b - 2.0 * cross_root)
    print("FID: ", fid)
    return fid

def calculate_all_metrics(synthetic_embeddings, original_embeddings):
    """Compute and print the MAUVE score between synthetic and original embeddings.

    Args:
        synthetic_embeddings: Features passed to ``compute_mauve`` as the
            first (p) argument.
        original_embeddings: Features passed to ``compute_mauve`` as the
            second (q) argument.

    Returns:
        The ``mauve`` attribute of the ``compute_mauve`` result, so callers can
        record it as well as read it from the printed output.
    """
    result = compute_mauve(synthetic_embeddings, original_embeddings)
    print("MAUVE: ", result.mauve)
    return result.mauve


def self_similarity(embeddings):
    """Print and return the mean pairwise cosine similarity of ``embeddings``.

    The full pairwise cosine-similarity matrix is built by broadcasting, the
    diagonal (each row compared with itself) is masked out, and the remaining
    entries are averaged.

    Args:
        embeddings: 2-D tensor, one embedding per row.

    Returns:
        The average off-diagonal similarity as a Python float. With fewer
        than two rows there are no off-diagonal entries and the result is NaN.
    """
    # Use the torch functional form explicitly: a bare ``cosine_similarity``
    # is not imported by this notebook's import cell.
    similarity_matrix = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
    mask = torch.ones_like(similarity_matrix) - torch.eye(similarity_matrix.size(0), device=similarity_matrix.device)
    masked_similarity_matrix = similarity_matrix * mask
    average_similarity = masked_similarity_matrix.sum() / mask.sum()
    print(f"Average similarity of the dataset: {average_similarity.item()}")
    return average_similarity.item()

In [5]:
def get_initial_population(syn_num,num_threads):
    """Generate ``syn_num`` initial samples in parallel and store them as epoch 1.

    The work is split across at most ``num_threads`` worker threads. Each
    worker keeps querying the chat model until it has collected its quota of
    responses whose token length lies in ``[80, 1000]``. All results are passed
    to ``save_listdata_to_json`` for ``get_path(1, 'all', 'json')``.

    Args:
        syn_num: Total number of samples to generate.
        num_threads: Upper bound on the number of worker threads.

    Note:
        Failed requests are retried indefinitely; errors are only printed.
    """
    def get_initdata(quota):
        # Collect ``quota`` accepted responses, retrying on errors and on
        # responses outside the accepted length range.
        ret_lst = []
        for i in range(quota):
            success = False
            while not success:
                try:
                    completion = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        temperature=1.2,
                        messages=[
                            {"role": "system", "content": "xxxxxxxxxxxxx"},
                            {"role": "user", "content": "xxxxxxxxxxxxx"},
                        ]
                    )
                    ret = completion.choices[0].message.content.replace("\n", " ")
                    ret_length = get_length(ret)  # tokenise once, not twice
                    if ret_length < 80 or ret_length > 1000:
                        continue
                    ret_lst.append(ret)
                    success = True
                except Exception as e:
                    print(f"Error on data generation for index {i+1}! Trying again...")
                    print(f"Error: {e}")
        return ret_lst

    total = max(int(syn_num), 0)
    # Never start more workers than there are samples, and always at least one.
    worker_count = max(1, min(int(num_threads), total))
    # Spread the remainder over the first workers so exactly ``total`` samples
    # are produced (plain integer division silently dropped the remainder and
    # produced nothing at all when syn_num < num_threads).
    base, extra = divmod(total, worker_count)
    quotas = [base + (1 if k < extra else 0) for k in range(worker_count)]

    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        futures = [executor.submit(get_initdata, quota) for quota in quotas]
        results = []
        with tqdm(total=len(futures), desc="get_initial_population") as pbar:
            for future in as_completed(futures):
                results.extend(future.result())
                pbar.update(1)
    save_listdata_to_json(results, get_path(1,'all','json'))


In [12]:

def filter_elite_choice(train_embeddings, embeddings_pt, num_choice=2000, sigma=None, threshold=0.73):
    """Select diverse, highly voted candidates from ``embeddings_pt``.

    Every row of ``train_embeddings`` votes for its nearest candidate
    (Euclidean distance). Optional Gaussian noise with standard deviation
    ``sigma`` is added to the vote counts, which are then clamped at zero.
    Candidates are visited in descending vote order and accepted only if their
    cosine similarity to every already accepted candidate is below the current
    threshold. If a full pass does not yield enough candidates, the threshold
    is raised by 0.01 and the pass is repeated.

    Args:
        train_embeddings: 2-D tensor of reference embeddings (the voters).
        embeddings_pt: 2-D tensor of candidate embeddings.
        num_choice: Number of candidates to select; capped at the number of
            candidates so the selection loop always terminates.
        sigma: Standard deviation of the vote noise; ``None`` or ``<= 0``
            disables noise.
        threshold: Initial cosine-similarity rejection threshold.

    Returns:
        List of selected row indices into ``embeddings_pt``, in acceptance order.
    """
    num_candidates = embeddings_pt.shape[0]
    # Asking for more items than exist could never be satisfied and made the
    # while-loop below spin forever.
    num_choice = min(num_choice, num_candidates)
    if num_choice <= 0:
        return []

    distances = torch.cdist(train_embeddings, embeddings_pt)
    closest_indices = torch.argmin(distances, dim=1)

    # One vote per training embedding for its nearest candidate.
    votes = torch.bincount(closest_indices, minlength=num_candidates).to(
        device=embeddings_pt.device, dtype=torch.get_default_dtype()
    )

    if sigma is not None and sigma > 0:
        print("Adding noise with sigma =", sigma)
        noise = torch.normal(0, sigma, size=votes.shape, device=votes.device)
        votes += noise
        votes = torch.clamp(votes, min=0)
    else:
        print("No noise added!")
    top_values, top_indices = torch.topk(votes, num_candidates)
    candidate_order = top_indices.tolist()

    selected_indices = []
    selected_lookup = set()  # O(1) membership test alongside the ordered list
    current_threshold = threshold

    while len(selected_indices) < num_choice:
        for idx in candidate_order:
            if len(selected_indices) >= num_choice:
                break
            if idx in selected_lookup:
                continue
            if selected_indices:
                selected_embeddings = embeddings_pt[selected_indices]
                similarity = F.cosine_similarity(embeddings_pt[idx].unsqueeze(0), selected_embeddings)
                # Reject candidates too similar to anything already chosen.
                if torch.any(similarity >= current_threshold):
                    continue

            selected_indices.append(idx)
            selected_lookup.add(idx)

        if len(selected_indices) < num_choice:
            # Not enough diverse candidates at this threshold: relax it.
            current_threshold = round(current_threshold + 0.01, 2)
    return selected_indices[:num_choice]

In [15]:
import torch

def select_tensors_by_probability(train_embeddings, probability):
    """Keep each row of ``train_embeddings`` independently with the given probability.

    Args:
        train_embeddings: Tensor whose first dimension indexes the samples.
        probability: Keep probability in ``[0, 1]`` applied to every row.

    Returns:
        A tensor containing the kept rows in their original order.
    """
    keep_probabilities = torch.full(
        (train_embeddings.size(0),), float(probability), device=train_embeddings.device
    )
    # Use the Bernoulli draw itself as the row mask; the previous code drew the
    # samples and then indexed with an unrelated, undefined ``selected_indices``.
    keep_mask = torch.bernoulli(keep_probabilities).bool()
    selected_tensors = train_embeddings[keep_mask]
    return selected_tensors


In [None]:
def mutate1(text, p=0.5):
    """Ask the chat model for a mutated variant of ``text``.

    A blanked-out copy of ``text`` is produced with ``transfer_blank`` and the
    model is queried until it returns a response whose token length differs
    from the original by at most 80 %.

    Args:
        text: Source text to mutate.
        p: Per-token blanking probability forwarded to ``transfer_blank``.
            It used to be read from an undefined global; the default mirrors
            the ``p=0.5`` of ``mutate2`` -- NOTE(review): confirm the intended
            value.

    Returns:
        The accepted model response with newlines replaced by spaces.

    Note:
        Exceptions are printed and the request is retried indefinitely.
    """
    length = get_length(text)
    # NOTE(review): ``blank`` is not referenced by the placeholder prompts
    # below; presumably the real user prompt embeds it.
    blank = transfer_blank(text,p)
    success = False
    while not success:
        try:
            completion = client.chat.completions.create(
                model="gpt-3.5-turbo",
                temperature=1.2,
                messages=[
                    {"role": "system", "content": "xxx"},
                    {"role": "user", "content": "xxx"},
                ]
            )
            ret = completion.choices[0].message.content.replace("\n", " ")
            # Reject responses whose length drifts more than 80 % from the source.
            if abs(length-get_length(ret)) / length > 0.8:
                continue
            success = True
            return ret
        except Exception as e:
            print(f"Error: {e}")

In [None]:
def mutate2(text, p=0.5):
    """Query the chat model until it returns a variant of ``text`` of similar length.

    A response is accepted when its token count differs from that of ``text``
    by at most 80 % of the original count.

    Args:
        text: Source text.
        p: Accepted for interface compatibility; not used by this function body.

    Returns:
        The accepted response with newlines replaced by spaces.

    Note:
        Any exception is printed and the request is retried indefinitely.
    """
    source_tokens = get_length(text)

    while True:
        try:
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                temperature=1.2,
                messages=[
                    {"role": "system", "content": "xxx"},
                    {"role": "user", "content": "xxx"},
                ]
            )
            candidate = response.choices[0].message.content.replace("\n", " ")
            relative_drift = abs(source_tokens-get_length(candidate)) / source_tokens
            if relative_drift > 0.8:
                # Too far from the source length: ask again.
                continue
            return candidate
        except Exception as e:
            print(f"Error: {e}")

In [None]:
def cross(text):
    """Query the chat model until it returns a crossover result of similar length.

    The token count of ``text`` is re-measured on every attempt, and a
    response is accepted when its token count differs from it by at most 80 %.

    Args:
        text: Source text passed to ``get_length``.

    Returns:
        The accepted response with newlines replaced by spaces.

    Note:
        Any exception (including one raised while measuring ``text``) is
        printed and the attempt is repeated indefinitely.
    """
    while True:
        try:
            source_tokens = get_length(text)
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                temperature=1.2,
                messages=[
                    {"role": "system", "content": "xxx"},
                    {"role": "user", "content": "xxx"},
                ]
            )
            candidate = response.choices[0].message.content.replace("\n", " ")
            relative_drift = abs(source_tokens-get_length(candidate)) / source_tokens
            if relative_drift > 0.8:
                continue
            return candidate
        except Exception as e:
            print(f"Error: {e}")

In [None]:
def generate():
    """Query the chat model until it returns a fresh sample of acceptable length.

    A target length is drawn with ``random_sample_length`` on every attempt.
    A response is accepted when its token count lies in ``[80, 1000]``.

    Returns:
        The accepted response with newlines replaced by spaces.

    Note:
        Any exception is printed and the attempt is repeated indefinitely.
    """
    while True:
        try:
            # NOTE(review): ``length`` is not referenced by the placeholder
            # prompts below; presumably the real prompt embeds it.
            length = random_sample_length()
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                temperature=1.2,
                messages=[
                    {"role": "system", "content": "xxx"},
                    {"role": "user", "content": "xxx"},
                ]
            )
            candidate = response.choices[0].message.content.replace("\n", " ")
            token_count = get_length(candidate)
            if not (80 <= token_count <= 1000):
                continue
            return candidate
        except Exception as e:
            print(f"Error: {e}")

In [None]:
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
elite_num = 2000      # candidates kept per epoch (num_choice of filter_elite_choice)
syn_num = 10000       # size of the initial population
max_workers = 100     # thread-pool size for the API-bound stages
T = 10                # number of epochs
sigma = 0 #eps=inf
sample_p = 1          # keep probability for sub-sampling the train embeddings (1 = use all)
filter_p = 0.73       # initial cosine-similarity threshold of filter_elite_choice
result_path = 'xxxxx'

part_sizes = []
train_embeddings = torch.load(train_emb_path)
# Converted once: the metric functions take NumPy arrays and the reference set
# does not change between epochs, so there is no need to reload it from disk.
train_embeddings_np = train_embeddings.cpu().numpy()


def _run_stage(worker, arg_tuples, desc):
    """Run ``worker(*args)`` for every tuple in ``arg_tuples`` on a thread pool.

    Returns the results in completion order while showing a progress bar.
    """
    outputs = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(worker, *args) for args in arg_tuples]
        # The bar total is the number of submitted tasks, so it stays correct
        # even when fewer than ``elite_num`` items were selected.
        with tqdm(total=len(pending), desc=desc) as pbar:
            for finished in as_completed(pending):
                outputs.append(finished.result())
                pbar.update(1)
    return outputs


# One client for the whole run. Prefer supplying the key through the
# OPENAI_API_KEY environment variable instead of committing it to the notebook;
# the literal is only a fallback.
client = OpenAI(base_url="xxxxxx",api_key=os.environ.get("OPENAI_API_KEY", "sk-xxxxxxx"))

for i in range(1,T+1):
    # --- Epoch 1 only: create (or top up) the initial population -----------
    if i == 1:
        if not os.path.exists(get_path(i, 'all', 'json')):
            get_initial_population(syn_num,max_workers)
        else:
            syn_now_num = len(load_json_data(get_path(i,'all','json')))
            if syn_now_num < syn_num:
                # One sample per worker for the missing items.
                get_initial_population(syn_num-syn_now_num,syn_num-syn_now_num)
    data_all = load_json_data(get_path(i,'all','json'))

    # --- Embeddings of the current population (recomputed if stale) --------
    if not os.path.exists(get_path(i, 'all_emb', 'pt')) or len(torch.load(get_path(i,'all_emb','pt')))!=len(data_all):
        get_embeddings(i,'all')

    embeddings_pt = torch.load(get_path(i,'all_emb','pt'))

    # --- Elite selection ---------------------------------------------------
    if sample_p != 1:
        selected_indices = filter_elite_choice(select_tensors_by_probability(train_embeddings, sample_p),embeddings_pt,elite_num,sigma,filter_p)
    else:
        selected_indices = filter_elite_choice(train_embeddings,embeddings_pt,elite_num,sigma,filter_p)
    if part_sizes != []:
        # How many of the selected items come from each part of the pool.
        count_selected = count_selected_indices(part_sizes, selected_indices)
        print(count_selected)

    part_sizes = [len(selected_indices), syn_num-len(selected_indices)]
    elite_texts = [data_all[indice] for indice in selected_indices]

    # NOTE(review): existing elite files are reused as-is, even if this run's
    # ``selected_indices`` differ from the ones that produced them.
    if not os.path.exists(get_path(i,'elite','json')):
        save_listdata_to_json(elite_texts,get_path(i,'elite','json'))

    if not os.path.exists(get_path(i, 'elite_emb', 'pt')):
        elite_embeddings_pt = embeddings_pt[selected_indices]
        torch.save(elite_embeddings_pt,get_path(i,'elite_emb','pt'))

    # --- Metrics on the elite set (file loaded once for both metrics) ------
    elite_embeddings_np = torch.load(get_path(i,'elite_emb','pt')).cpu().numpy()
    calculate_all_metrics(elite_embeddings_np,train_embeddings_np)
    calculate_fid(elite_embeddings_np,train_embeddings_np)
    if i == T: 
        break

    # --- Variation operators, each saved to its own file -------------------
    results_mutate1 = _run_stage(mutate1, [(text,) for text in elite_texts], "mutate1")
    save_listdata_to_json(results_mutate1, get_path(i,'mutate1','json'))

    results_mutate2 = _run_stage(mutate2, [(text,) for text in elite_texts], "mutate2")
    save_listdata_to_json(results_mutate2, get_path(i,'mutate2','json'))

    results_cross = _run_stage(cross, [(text,) for text in elite_texts], "cross")
    save_listdata_to_json(results_cross, get_path(i,'cross','json'))

    results_generate = _run_stage(generate, [()] * elite_num, "generate")
    save_listdata_to_json(results_generate, get_path(i,'generate','json'))

    # --- Build the population of the next epoch ----------------------------
    aggregate_data(i,['elite','mutate1','mutate2','cross','generate']) 

In [None]:
# Report the average pairwise cosine similarity of (up to) 1000 randomly chosen
# embeddings for epochs 1..10.
# Fall back to the CPU when no GPU is present instead of failing on 'cuda:0'.
similarity_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
for i in range(10):
    embeddings = torch.load(get_path(i+1,'all_emb','pt'))
    indices = torch.randperm(embeddings.size(0))[:1000]
    random_samples = embeddings[indices]
    self_similarity(random_samples.to(similarity_device))

In [None]:
# MAUVE and FID of the population of epochs 2..10 against the train embeddings,
# skipping the first 2000 rows of each epoch file.
# NOTE(review): 2000 presumably equals ``elite_num`` (the elite block that
# aggregate_data places first) -- confirm before changing either value.
# The reference embeddings are loop-invariant, so load and convert them once.
reference_embeddings_np = torch.load(train_emb_path).cpu().numpy()
for i in range(1,10):
    # Load each epoch file once and reuse it for both metrics.
    offspring_embeddings_np = torch.load(get_path(i+1,'all_emb','pt'))[2000:].cpu().numpy()
    calculate_all_metrics(offspring_embeddings_np,reference_embeddings_np)
    calculate_fid(offspring_embeddings_np,reference_embeddings_np)