In [None]:
import sys
import csv
import os
import random
from utils.helpers import *
from dataset import CustomImageFolder
from torchvision import transforms
from functools import partial
import torch
from functions import *
from clustering_function import text_low,text_mid,text_high
from torchvision.datasets import ImageFolder
from clustering_function import cluster
from sklearn.manifold import TSNE
from setting import *
from sklearn.metrics import silhouette_score


# ---------------------------------------------------------------------------
# Model and data setting
# ---------------------------------------------------------------------------
mi = ""  # prefix inserted into the checkpoint directory and result file names
device = torch.device("cuda:7")

# normalize_ maps [0, 1] -> [-1, 1] (2x - 1); un_normalize_ is its inverse ((x + 1) / 2).
normalize_ = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
un_normalize_ = transforms.Normalize([-1] * 3, [2] * 3)

# Dataset name -> image folder path (paths come from `setting`).
data_dir = dict(COCO=MSCOCO_TEST_PATH, CelebA=CELEBAHQ_VAL_PATH)

# ---------------------------------------------------------------------------
# Forge setting
# ---------------------------------------------------------------------------
text_num = 4            # number of clusters / messages; values 1 and 4 have been used
# set_seeds(2024)
message_len = 100       # length of each watermark message
tolerant = cal_tolerant(message_len)  # max mismatched positions still counted as a correct decode
batch_size = 32         # batch size passed to get_residual_prediction when use_vae is on
alpha = 1               # scale applied to the attack pattern when it is added to clean images
use_cluster = True      # average the target's cluster (True) or only the target message's residuals (False)
use_vae = False         # recompute residuals with get_residual_prediction(method='VAE')
attack_num = 30         # watermarked images generated per message
random_sample = True    # pick those images at random (True) or as contiguous index ranges (False)

# Forgery experiment: for every checkpoint, estimate a watermark "attack pattern"
# from residuals of watermarked images, add it to clean images, and measure how
# often the decoder reads back the target message. Results are appended to a CSV.
for model_choice in ["hidden"]:
    for data_choice in ["COCO"]:
        transform_pipe = [
            transforms.Resize((image_resolution[data_choice], image_resolution[data_choice])),
            transforms.ToTensor(),
        ]
        ckp_dir = f"weights/{mi}weight/{data_choice}/{model_choice}"
        if model_choice == 'hidden':
            # The 'hidden' model is fed images normalized to [-1, 1].
            transform_pipe.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
        transform = transforms.Compose(transform_pipe)
        if data_choice == "CelebA":
            ds = ImageFolder(data_dir[data_choice], transform=transform)
        else:
            ds = CustomImageFolder(data_dir[data_choice], transform, num=1000)

        # Bind dataset / device / model into the helper functions.
        # NOTE: this rebinds the module-level names (kept for compatibility with
        # other cells); functools.partial flattens nested partials, so re-running
        # the cell just overrides the bound keywords.
        get_wmimages = partial(get_wmimages, ds=ds, device=device, model_choice=model_choice)
        get_images = partial(get_images, ds=ds, device=device)
        extract_message = partial(extract_message, device=device, model_choice=model_choice)

        # Put the reference checkpoint ("{mi}OO") first in evaluation order.
        ckp_list = sorted(os.listdir(ckp_dir))
        reference_ckp = f"{mi}OO"
        if reference_ckp in ckp_list:
            ckp_list.remove(reference_ckp)
            ckp_list.insert(0, reference_ckp)

        # The clean images that get forged are identical for every checkpoint,
        # so load them once instead of once per checkpoint.
        original_images = get_images(image_i=range(900, 1000)).to(device)

        for gpname, text_strings in zip(["mid"], [text_mid]):
            csv_file = f"../forged_results/oriforge/{data_choice}_{model_choice}_{gpname}_{text_num}_{mi}_alpha{alpha}_{attack_num}.csv"
            # Create the output directory up front so the append at the end of a
            # long run cannot fail on a missing folder.
            os.makedirs(os.path.dirname(csv_file), exist_ok=True)
            text_list = [generate_message(message_len, t, 1) for t in text_strings]

            for op in ckp_list:
                # The checkpoint file name depends on the checkpoint family.
                # Match the reference with or without the `mi` prefix, consistent
                # with how it is moved to the front of ckp_list above.
                if op in ("OO", reference_ckp):
                    ckp_name = 'epoch_499_state100.pth'
                elif op == 'emperical':
                    ckp_name = 'epoch_499_state.pth'
                else:
                    ckp_name = 'epoch_99_state.pth'
                ckp_path = os.path.join(ckp_dir, op, ckp_name)

                encoder, decoder = load_weights(ckp_path, model_choice, message_len)
                encoder.eval()
                decoder.eval()

                # Target: the last message embedded into the last dataset image.
                target_image, target_residual = get_wmimages(text=text_list[-1], image_i=[-1], encoder=encoder)

                # Watermark `attack_num` images with each message.
                if random_sample:
                    wm_images, residual_predictions = zip(*[
                        get_wmimages(text=t, image_i=random.sample(range(1000), attack_num), encoder=encoder)
                        for t in text_list
                    ])
                else:
                    # Message i uses images [i * attack_num, (i + 1) * attack_num).
                    wm_images, residual_predictions = zip(*[
                        get_wmimages(text=t, image_i=range(i * attack_num, (i + 1) * attack_num), encoder=encoder)
                        for i, t in enumerate(text_list)
                    ])
                if use_vae:
                    residual_predictions = [
                        get_residual_prediction(wm_images_i, batch_size, device=device, method='VAE')
                        for wm_images_i in wm_images
                    ]
                wm_images = list(wm_images)
                residual_predictions = list(residual_predictions)
                # The target residual is always the LAST entry / row from here on.
                residual_predictions.append(target_residual)

                if text_num != 1:
                    # t-SNE embedding of all residuals (one row per residual image).
                    data = torch.cat(residual_predictions, dim=0).detach().cpu()
                    # Derive the row count from the tensor instead of assuming
                    # text_num * attack_num + 1, which silently mis-reshapes when
                    # the number of messages differs from text_num.
                    flattened_data = data.reshape(data.size(0), -1).numpy()
                    tsne = TSNE(n_components=2, random_state=5000)
                    reduced_data = tsne.fit_transform(flattened_data)

                    # Cluster the embedding and keep the cluster holding the target.
                    cluster_labels, cluster_centers = cluster(reduced_data, text_num, "km")
                    score = silhouette_score(reduced_data, cluster_labels)
                    print("score:", score)
                    target_label = cluster_labels[-1]
                    if use_cluster:
                        attack_samples = data[cluster_labels == target_label]
                    else:
                        # Residuals of the images watermarked with the target message.
                        attack_samples = residual_predictions[-2]
                        print(attack_samples.shape)

                    # Average in [0, 1] space, then map the mean back to [-1, 1].
                    attack_samples = un_normalize_(attack_samples)
                    attack_sample = attack_samples.mean(axis=0)
                    attack_sample = normalize_(attack_sample)
                    attack_pattern = attack_sample.unsqueeze(0).to(device)
                else:
                    attack_sample = residual_predictions[-1]
                    attack_pattern = attack_sample.to(device)

                # Forge: add the scaled pattern to clean images and decode.
                forged_image = torch.clamp(original_images + alpha * attack_pattern, min=-1, max=1)
                forged_prediction, _ = extract_message(image_tensor=forged_image, decoder=decoder, model_choice=model_choice)
                gt_text = text_list[-1].repeat(forged_prediction.size(0), 1).to(device)
                difference = (forged_prediction != gt_text).float()
                # A forgery counts as successful when at most `tolerant` positions differ.
                acc = (difference.sum(dim=1) <= tolerant).float().mean(dim=0).item()
                bitwise_accuracy = (1.0 - difference.mean(dim=1)).mean()
                print(f"model:{op}-forged_bit_acc:{bitwise_accuracy.item()} acc:{acc}")

                # Append one row per checkpoint; write the header only for a new file.
                csv_dict = {"model": op, "bit_acc": bitwise_accuracy.item(), "acc": acc}
                with open(csv_file, mode='a', newline='') as file:
                    writer = csv.DictWriter(file, fieldnames=csv_dict.keys())
                    if file.tell() == 0:
                        writer.writeheader()
                    writer.writerow(csv_dict)
 
    

score: 0.6765941
model:GN_0.25-forged_bit_acc:0.4631999731063843 acc:0.0
score: 0.6304311
model:GN_0.25-forged_bit_acc:0.46380001306533813 acc:0.0
score: 0.6517542
model:GN_0.25-forged_bit_acc:0.4616999626159668 acc:0.0
score: 0.6943655
model:GN_0.25-forged_bit_acc:0.46309995651245117 acc:0.0
score: 0.7044579
model:GN_0.25-forged_bit_acc:0.46490001678466797 acc:0.0
score: 0.6594108
model:GN_0.25-forged_bit_acc:0.46220001578330994 acc:0.0
score: 0.35920802
model:GN_0.25-forged_bit_acc:0.46220001578330994 acc:0.0
score: 0.71154994
model:GN_0.25-forged_bit_acc:0.46470001339912415 acc:0.0
score: 0.62132776
model:GN_0.25-forged_bit_acc:0.4602999687194824 acc:0.0


KeyboardInterrupt: 

In [None]:
import pandas as pd


def process_csv(file1, file2, output_file, score):
    """Merge one metric from two result CSVs into a side-by-side comparison CSV.

    Rows are matched on the ``model`` column after a leading ``"mi"`` has been
    stripped from the model names in ``file2``.

    Args:
        file1: Path of the baseline CSV; must contain ``model`` and the metric column.
        file2: Path of the CSV to compare against; same columns, model names
            possibly prefixed with ``"mi"``.
        output_file: Path the merged CSV is written to.
        score: Metric to compare. ``"forgedacc"`` and ``"xyscore"`` are aliases
            for the ``bit_acc`` column; any other value is used as the column
            name directly.

    The written CSV has columns ``model``, ``wo_bit_acc`` (value from ``file1``)
    and ``w_bit_acc`` (value from ``file2``, falling back to the ``file1`` value
    for models missing from ``file2``). Both values are formatted with 4
    decimals. Row order follows ``file1``.
    """
    df1 = pd.read_csv(file1)
    df2 = pd.read_csv(file2)
    if score in ('forgedacc', 'xyscore'):
        score = 'bit_acc'

    # Strip only a *leading* "mi". A plain substring replace would also mangle
    # names containing "mi" elsewhere (e.g. "miGN_mid" -> "GN_d") and break the merge.
    df2['model'] = df2['model'].str.replace(r'^mi', '', regex=True)

    # Keep only the model name and the requested metric, renamed per source.
    bit_acc1 = df1[['model', score]].rename(columns={score: 'wo_bit_acc'})
    bit_acc2 = df2[['model', score]].rename(columns={score: 'w_bit_acc'})

    # Left join: every model from file1 is kept, in file1's order.
    merged_df = pd.merge(bit_acc1, bit_acc2, on='model', how='left')

    # Models missing from file2 fall back to their file1 value. Assign the result
    # rather than calling fillna(inplace=True) on a column selection: that is
    # chained assignment and no longer updates merged_df under Copy-on-Write.
    merged_df['w_bit_acc'] = merged_df['w_bit_acc'].fillna(merged_df['wo_bit_acc'])

    # Format both metric columns as fixed 4-decimal strings.
    for column in ('wo_bit_acc', 'w_bit_acc'):
        merged_df[column] = merged_df[column].apply(lambda x: f"{x:.4f}")

    merged_df.to_csv(output_file, index=False)
    

# Build baseline-vs-"mi" comparison CSVs for the selected metric.
# "forgedacc" -> forged_results, "score" -> KM_results/score; any other value
# (e.g. "xyscore") -> xyattackresults. The two input files differ only in the
# "" / "mi" tag embedded in their names.
_RESULTS_ROOT = "/data/shared/Huggingface/sharedcode/Stegastamp_CR"

score = "xyscore"
for dataset in ["CelebA"]:
    for model in ["stega"]:
        stem = f"{dataset}_{model}"
        if score == "forgedacc":
            folder = f"{_RESULTS_ROOT}/forged_results"
            csv_file1, csv_file2 = (f"{folder}/{stem}_mid_4_{tag}_alpha1_30.csv" for tag in ("", "mi"))
            output_file = f"{folder}/comparsion_{stem}_{score}.csv"
        elif score == "score":
            folder = f"{_RESULTS_ROOT}/KM_results/score"
            csv_file1, csv_file2 = (f"{folder}/{stem}_rand_4_{tag}.csv" for tag in ("", "mi"))
            output_file = f"{folder}/comparsion_{stem}_{score}.csv"
        else:
            folder = f"{_RESULTS_ROOT}/xyattackresults"
            csv_file1, csv_file2 = (f"{folder}/{stem}_{tag}_1_3.csv" for tag in ("", "mi"))
            output_file = f"{folder}/comparsion_{stem}_{score}_1_3.csv"
        process_csv(csv_file1, csv_file2, output_file, score)


