# In [None]:
import sys
from utils.helpers import *
from dataset import CustomImageFolder
from torchvision import transforms
from functools import partial
import torch
from functions import *
import matplotlib.pyplot as plt
from clustering_function import text_low,text_mid,text_high
from torchvision.datasets import ImageFolder
import os
import torch
from clustering_function import cluster
from sklearn.manifold import TSNE
from sklearn.metrics import homogeneity_score,silhouette_score
import csv
from sklearn.metrics import pairwise_distances_argmin_min
from setting import *


# model and data setting
set_seeds(2024)
model_choice = "hidden"
data_choice = "COCO"
mi = ""
ckp_dir = f"weights/{mi}weight/{data_choice}/{model_choice}"
# Checkpoint sub-directories in name order, with the f"{mi}OO" entry (when
# present) pinned to the front so it is evaluated first.
ckp_list = sorted(os.listdir(ckp_dir), key=lambda entry: (entry != f"{mi}OO", entry))
device = torch.device("cuda:7")
data_dir = {"COCO":MSCOCO_TEST_PATH,"CelebA":CELEBAHQ_VAL_PATH}
# Base preprocessing: resize to a square whose side is the per-dataset
# resolution, then convert to a tensor. More steps may be appended later.
transform_pipe = [
    transforms.Resize((image_resolution[data_choice],) * 2),
    transforms.ToTensor(),
]


# forge setting
set_seeds(2024)
message_len = 100
text_num = 4
random_sample = False
# Scores are appended to this CSV; the name encodes dataset, model,
# sampling mode, number of texts and the `mi` prefix.
csv_file = f"../KM_results/oriscore/{data_choice}_{model_choice}_{'rand' if random_sample else ''}_{text_num}_{mi}.csv"
# Only the 'hidden' model gets the extra per-channel normalisation
# (mean 0.5, std 0.5) appended to the shared pipeline.
if model_choice == 'hidden':
    transform_pipe.append(transforms.Normalize(mean=[0.5]*3,std=[0.5]*3))
transform = transforms.Compose(transform_pipe)
# CelebA is read through torchvision's ImageFolder; every other dataset goes
# through the project's CustomImageFolder with num=1000.
ds = (
    ImageFolder(data_dir[data_choice],transform=transform)
    if data_choice == "CelebA"
    else CustomImageFolder(data_dir[data_choice],transform,num=1000)
)

# tsne forge setting
use_vae = False
batch_size = 16
attack_num = 30

# functions
# Rebind the project helpers under their own names with the dataset, device
# and model choice pre-filled, so later calls only pass what varies.
get_wmimages = partial(
    get_wmimages,
    ds=ds,
    device=device,
    model_choice=model_choice,
)
get_images = partial(
    get_images,
    ds=ds,
    device=device,
)
extract_message = partial(
    extract_message,
    device=device,
    model_choice=model_choice,
)


# One message per entry of text_mid; each message is embedded into its own
# group of attack_num images below.
# NOTE(review): the reshape and the per-cluster loop below assume
# len(text_mid) == text_num — confirm.
text_list = [generate_message(message_len,t,1) for t in text_mid]

# Make sure the results directory exists before any checkpoint is processed,
# so a missing folder cannot abort the run at the first CSV write after the
# expensive encoding / t-SNE work has already been done.
os.makedirs(os.path.dirname(csv_file) or ".", exist_ok=True)

for op in ckp_list:
    # Pick the checkpoint file for this sub-directory. The baseline directory
    # is looked up as f"{mi}OO" when ckp_list is built, so accept that
    # spelling as well as the bare "OO".
    if op in ("OO", f"{mi}OO"):
        ckp_path = os.path.join(ckp_dir,op,'epoch_499_state100.pth')
    elif op == "emperical":
        ckp_path = os.path.join(ckp_dir,op,'epoch_499_state.pth')
    else:
        ckp_path = os.path.join(ckp_dir,op,'epoch_99_state.pth')
    encoder,decoder = load_weights(ckp_path,model_choice,message_len)

    # Watermark one group of images per message. get_wmimages returns a pair
    # per message, unzipped here into watermarked images and residual
    # predictions.
    # NOTE(review): `random` (and `np` further down) are not imported
    # explicitly in this file; presumably they come from a star import — confirm.
    if random_sample:
        # attack_num image indices drawn at random from 0..999 for each message.
        wm_images,residual_predictions = zip(*[
            get_wmimages(text=t,image_i=random.sample(range(1000),attack_num),encoder=encoder)
            for t in text_list
        ])
    else:
        # Consecutive, non-overlapping index ranges: message i gets images
        # [i*attack_num, (i+1)*attack_num).
        wm_images,residual_predictions = zip(*[
            get_wmimages(text=t,image_i=range(i*attack_num,(i+1)*attack_num),encoder=encoder)
            for i,t in enumerate(text_list)
        ])

    # Optionally replace the residual predictions with ones recomputed from
    # the watermarked images using the "VAE" method.
    if use_vae:
        residual_predictions = [get_residual_prediction(wm_images_i,batch_size,device,method="VAE") for wm_images_i in wm_images]

    wm_images = list(wm_images)
    residual_predictions = list(residual_predictions)

    # TSNE: stack all residuals, flatten each sample to one row and embed in 2-D.
    data = torch.cat(residual_predictions,dim=0).detach().cpu()
    # Message index per sample; only needed by the commented-out homogeneity score.
    labels = torch.cat([torch.ones(attack_num,)*i for i in range(text_num)]).detach().numpy()
    flattened_data = data.view(text_num*attack_num,-1).numpy()

    tsne = TSNE(n_components=2, random_state=5000)
    reduced_data = tsne.fit_transform(flattened_data)

    # cluster the 2-D embedding into text_num groups and score the partition
    cluster_labels,cluster_centers = cluster(reduced_data,text_num,"km")
    # score = homogeneity_score(labels,cluster_labels)
    score = silhouette_score(reduced_data,cluster_labels)

    # Distance from every embedded point to its nearest cluster centre,
    # averaged within each cluster.
    _, distances = pairwise_distances_argmin_min(reduced_data, cluster_centers)
    average_distances = [np.mean(distances[cluster_labels == i]) for i in range(text_num)]

    for i, avg_distance in enumerate(average_distances):
        print(f"Cluster {i}: Average Distance to Centroid = {avg_distance:.2f}")
    print(f"model:{op}-score:{score}")

    # Append one row per checkpoint; the header is written only when the file
    # is still empty.
    csv_dict = {"model":op,"score":score}
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=list(csv_dict))
        if file.tell() == 0:
            writer.writeheader()
        writer.writerow(csv_dict)
