In [None]:
import torch
# Notebook-wide settings.
# NOTE(review): dataset_type is not read by the cells shown here; paths below
# hard-code "CIFAR10".
dataset_type = "CIFAR10"
sample_2D = True          # gates the `if sample_2D:` cell that samples 2-D points and renders images
sample_208D = False       # presumably toggles sampling directly in the 208-D latent space; not read in the cells shown
save_origin_pic = False   # not read in the cells shown here

# Fall back to CPU so the notebook still runs on machines without CUDA
# (the original hard-coded "cuda:0" made every later .to(device) fail there).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### 是否从2D空间中采样

In [None]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.utils as utils
import numpy as np
from scipy.stats import norm
import torch

# 插值函数
def get_zs_prevent_stick(coordinates, kdTree_2D, latent_z, k=10):
    """Interpolate one latent code for every query coordinate.

    For each query point the nearest stored point is the first interpolation
    anchor. Among the other ``k - 1`` neighbours, the second anchor is the one
    that minimises the triangle slack ``d(q, a0) + d(q, aj) - d(a0, aj)``.
    By the triangle inequality this slack is >= 0, and it is 0 when the query
    lies on the segment a0-aj. The most "obtuse" triangle therefore keeps the
    two anchors on opposite sides of the query, which avoids both anchors
    sticking to the same side. The result is the inverse-distance weighted
    blend of the two anchors' latent codes.

    Args:
        coordinates: array-like of query points, one row per point (the caller
            passes 2-D coordinates).
        kdTree_2D: a ``scipy.spatial.KDTree``-like object exposing ``query`` and
            ``data``; row ``i`` of ``data`` corresponds to row ``i`` of ``latent_z``.
        latent_z: tensor whose row ``i`` is the latent code of ``kdTree_2D.data[i]``.
        k: number of nearest neighbours to consider; must be >= 2.

    Returns:
        Tensor of shape ``(len(coordinates), *latent_z.shape[1:])`` holding one
        interpolated latent code per query point.

    Raises:
        ValueError: if ``k < 2`` (a second anchor is required).
    """
    print("进入了防止粘在一块~~~~~~~~~~~~~~~~~~")
    if k < 2:
        raise ValueError(f"k must be >= 2 to pick a second anchor, got {k}")

    coordinates = np.asarray(coordinates)
    if len(coordinates) == 0:
        # Nothing to interpolate; keep the trailing latent shape and dtype.
        return latent_z.new_empty((0, *latent_z.shape[1:]))

    # One batched query for all points; with k >= 2 both results are (n, k).
    nearest_distance, nearest_index = kdTree_2D.query(coordinates, k=k)
    origin_coordinates = kdTree_2D.data  # the stored points the tree was built on

    zs = []  # collect then stack once: repeated torch.cat would be O(n^2)
    for dists, idxs in zip(nearest_distance, nearest_index):
        nearest_pos = origin_coordinates[idxs[0]]
        s1 = dists[0]

        # Pick the second anchor with the smallest triangle slack. Start from
        # +inf (not a fixed constant) so the true argmin is always selected;
        # strict "<" keeps the first (closest) neighbour on ties.
        best_j = 1
        min_slack = np.inf
        for j in range(1, k):
            s3 = np.linalg.norm(nearest_pos - origin_coordinates[idxs[j]])
            slack = (s1 + dists[j]) - s3  # >= 0 by the triangle inequality, no abs() needed
            if slack < min_slack:
                min_slack = slack
                best_j = j

        z0 = latent_z[idxs[0]].clone().detach()
        z1 = latent_z[idxs[best_j]].clone().detach()
        s2 = dists[best_j]
        total = s1 + s2
        if total == 0:
            # Query coincides with both anchors (duplicate stored points):
            # the weights would be 0/0, so take the nearest code directly.
            zs.append(z0)
        else:
            # Inverse-distance weights: the closer anchor gets the larger weight.
            zs.append(float(s2 / total) * z0 + float(s1 / total) * z1)

    return torch.stack(zs, dim=0)


# 用来处理zs的类，方便使用batchsize
class Mydata_sets(Dataset):
    """Map-style dataset that serves the rows of ``zs`` one at a time.

    Wrapping the latent codes in a ``Dataset`` lets a ``DataLoader`` feed
    them to the generator in fixed-size batches.

    Args:
        zs: any indexable, sized container of latent codes (the caller passes
            a tensor whose first dimension indexes the codes).
    """

    def __init__(self, zs):
        super().__init__()
        self.zs = zs  # stored as-is; no copy is made

    def __len__(self):
        """Return the number of stored latent codes."""
        return len(self.zs)

    def __getitem__(self, index):
        """Return the latent code at position ``index``."""
        return self.zs[index]

if sample_2D:
    # The original cell used sys.path without importing sys (NameError).
    import os
    import sys
    from scipy import spatial

    # Load the pickled KD-tree over the 2-D coordinates of the 50k images.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # files; PyTorch >= 2.6 defaults to weights_only=True and would reject a KDTree.
    kdTree = torch.load("./static/data/CIFAR10/2D_kdTree/2D_kdTree_50000_png.pt")
    coords = kdTree.data

    # Fit a multivariate Gaussian to the 2-D coordinates and draw new points from it.
    # np.cov treats rows as variables, hence the transpose.
    cov = np.cov(coords.T)
    new_coords = np.random.multivariate_normal(np.mean(coords, axis=0), cov, size=50000)

    latent_z_path = "./static/data/CIFAR10/latent_z/BigGAN_random_50k_png_208z_50000.pt"
    # The tensors were saved while on the GPU; map them to CPU so loading works anywhere.
    latent_z = torch.load(latent_z_path, map_location="cpu")

    # NOTE(review): norm_tree is not used in the cells shown; kept in case a later cell reads it.
    norm_tree = spatial.KDTree(data=new_coords)

    # Map every sampled 2-D point to a latent code by interpolating its neighbours.
    zs = get_zs_prevent_stick(new_coords, kdTree, latent_z)

    # Wrap the codes in a Dataset so they can be fed to G in batches.
    zs_datasets = Mydata_sets(zs)
    zs_loader = DataLoader(zs_datasets, batch_size=200, shuffle=False, num_workers=1)

    model_files_dir = "./model_files/"  # location of the model package
    sys.path.append(model_files_dir)
    import model_files as model_all
    checkpoints_path = "./model_files/CIFAR10/checkpoints/BigGAN/model=G-best-weights-step=392000.pth"
    G = model_all.get_generative_model("CIFAR10").to(device)
    G.load_state_dict(torch.load(checkpoints_path, map_location=device)["state_dict"])
    G.eval()

    # save_image fails if the target directory is missing, so create it up front.
    out_dir = "./临时垃圾-随时可删/2D_50k_png"
    os.makedirs(out_dir, exist_ok=True)

    count = 0
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for batch_z in zs_loader:
            # batch_z is already a tensor; torch.tensor() on it would copy and warn.
            z = batch_z.to(device=device, dtype=torch.float32)
            imgs = G(z)
            # Assumes G outputs values in [-1, 1]; rescale to [0, 1] and move to CPU once per batch.
            imgs_cpu = ((imgs + 1) / 2).clamp(0.0, 1.0).cpu()
            for img in imgs_cpu:
                utils.save_image(img, f'{out_dir}/pic_{count}.png')
                count += 1