In [5]:



import pickle
import os
import numpy as np
from tqdm import tqdm
from PIL import Image  
from openTSNE import TSNE
import matplotlib.pyplot as plt
import copy
import pandas as pd
import seaborn as sns
import random

import torch
import torch.nn as nn
import torchvision
import torchvision.utils as utils
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset

import h5py



# Run on the second GPU when the host has one; otherwise fall back to the
# first GPU, and finally to the CPU, so the notebook still runs on machines
# with fewer than two CUDA devices.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
dataset_type = "CIFAR10"
model_name = "null"
save_origin_pic = False  # whether to keep the original images
save_generate_pic = False  # whether to keep the generated images
fid_judge = False  # whether to run the FID evaluation



# Project-local modules live outside the default import path, so each
# directory is appended to sys.path before the corresponding imports.
import sys
python_files_dir = "./python_files/" # location of the local Python helper modules
sys.path.append(python_files_dir)
import my_tools
import fid_score as official_fid

model_files_dir = "./model_files/" # location of the model files
sys.path.append(model_files_dir)
import model_files as all_model


if dataset_type == "CIFAR10":
    # Preprocessing pipeline: PIL image -> uint8 tensor -> float tensor ->
    # 224x224 bicubic resize -> per-channel normalisation.
    # The mean/std values are presumably the usual ImageNet statistics.
    # InterpolationMode is used instead of the PIL integer constant, which
    # torchvision has deprecated as an argument to Resize.
    transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # This .pt file holds the whole pickled model object, not just its parameters.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects fully pickled modules; if this
    # checkpoint is trusted, weights_only=False may be needed there.
    ResNet50 = torch.load("./model_files/CIFAR10/checkpoints/classify_model/ResNet50.pt", map_location=device)
    ResNet50 = ResNet50.to(device)
    ResNet50.eval()  # inference mode: fixes batch-norm statistics / disables dropout

    # Despite its name, this loader wraps the CIFAR10 *test* split (train=False).
    train_dataloader = DataLoader(datasets.CIFAR10('./static/data/CIFAR10/CIFAR10/', train=False, download=True, transform=transform), batch_size=128, shuffle=False)

class Mydata_sets(Dataset):
    """Dataset over a flat directory of images named ``<4 chars><integer id>.<ext>``.

    Example file name: ``pic_12.png`` (id 12). Items are ordered by ascending
    integer id. Directory entries whose name does not follow the pattern
    (for instance ``.ipynb_checkpoints`` or ``.DS_Store``) are ignored rather
    than aborting construction with a ``ValueError``.

    Args:
        path: directory containing the image files.
        transform: optional callable applied to each opened PIL image.

    Attributes:
        root_dir: the directory passed as ``path``.
        img_names: file names, sorted by integer id.
        img_ids: integer ids, aligned index-by-index with ``img_names``.
        transform: the callable passed as ``transform`` (may be ``None``).
    """

    def __init__(self, path, transform=None):
        super(Mydata_sets, self).__init__()
        self.root_dir = path

        # Parse every id exactly once and keep it next to its file name.
        indexed_names = []
        for name in os.listdir(self.root_dir):
            img_id = self._parse_id(name)
            if img_id is not None:
                indexed_names.append((img_id, name))
        indexed_names.sort(key=lambda pair: pair[0])

        self.img_ids = [img_id for img_id, _ in indexed_names]
        self.img_names = [name for _, name in indexed_names]
        print(self.img_names[:100])
        self.transform = transform

    @staticmethod
    def _parse_id(file_name):
        """Return the integer id encoded in ``file_name``, or ``None`` if it has none."""
        stem, ext = os.path.splitext(file_name)
        if not ext:
            return None
        try:
            return int(stem[4:])  # skip the 4-character prefix, e.g. "pic_"
        except ValueError:
            return None

    def __getitem__(self, index):
        """Return ``(image, id)``; ``id`` is a 0-d integer tensor parsed from the file name."""
        img = Image.open(os.path.join(self.root_dir, self.img_names[index]))
        id_name = torch.tensor(self.img_ids[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, id_name

    def __len__(self):
        """Number of image files that matched the naming pattern."""
        return len(self.img_names)

if dataset_type == "CIFAR10":
    # Preprocessing applied to every image fed to the classifier:
    # PIL image -> uint8 tensor -> float tensor -> 224x224 bicubic resize ->
    # per-channel normalisation (presumably the usual ImageNet statistics).
    # InterpolationMode replaces the PIL integer constant, which torchvision
    # has deprecated as an argument to Resize.
    transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

# NOTE(review): not referenced in the visible part of this cell; kept in case later cells use it.
index = 0

# with torch.no_grad():
#     train_dataloader = tqdm(train_dataloader)
#     total = 0
#     correct = 0
#     for inputs, labels in train_dataloader:
#             inputs,labels = inputs.to(device), labels.to(device)
#             # ============= forward =============
#             outputs = ResNet50(inputs)
#             # ============= precision ===========
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
#             #break
#             description = 'correct: %.4f, total: %.4f , accuracy: %.4f' % (correct, total, correct/total)
#             train_dataloader.set_description(description)
#             train_dataloader.update()
# print("最终结果： ", description)

# Per-batch accumulators. They are concatenated after the loop into:
#   features - image features taken from the network with its fc layer removed
#   labels   - class index predicted by the full model
#   ids      - integer id parsed from each image file name
feature_batches, label_batches, id_batches = [], [], []

feature_model = copy.deepcopy(ResNet50)
feature_model.fc = nn.Identity()  # effectively removes the fc layer, so the model outputs the pre-fc features
label_model = copy.deepcopy(ResNet50)
# NOTE(review): if ResNet50 is the stock torchvision architecture, the logits
# could be obtained from the features via the original fc layer instead of a
# second full forward pass - confirm the model's forward() before changing.

# Image directory
if dataset_type == "CIFAR10":
    pic_path = "./static/data/CIFAR10/pic/random_50k_png"

img_datasets = Mydata_sets(pic_path, transform=transform)
# img_datasets = torchvision.datasets.CIFAR10("./static/data/CIFAR10/", train=True, download=False, transform=transform) # use the original CIFAR10 images
imgLoader = torch.utils.data.DataLoader(img_datasets, batch_size=128, shuffle=False, num_workers=4)  # loader configuration

with torch.no_grad():
    for images, image_ids in tqdm(imgLoader):
        images = images.to(device)
        id_batches.append(image_ids)  # N

        # flatten(start_dim=1) keeps the batch axis even for a batch of one
        # image, which a bare squeeze() would drop. Moving each result to the
        # CPU right away avoids holding every batch on the GPU until the end.
        feature_batches.append(feature_model(images).flatten(start_dim=1).cpu())  # N, feature_dim

        logits = label_model(images)
        label_batches.append(torch.argmax(logits, dim=1).cpu())  # N

features = torch.cat(feature_batches, dim=0).numpy()  # (n, feature_dim)
labels = torch.cat(label_batches, dim=0).numpy()  # (n,)
ids = torch.cat(id_batches, dim=0).cpu().numpy()  # (n,)

print(features.shape)
print(labels.shape)
print(ids.shape)

Files already downloaded and verified
['pic_0.png', 'pic_1.png', 'pic_2.png', 'pic_3.png', 'pic_4.png', 'pic_5.png', 'pic_6.png', 'pic_7.png', 'pic_8.png', 'pic_9.png', 'pic_10.png', 'pic_11.png', 'pic_12.png', 'pic_13.png', 'pic_14.png', 'pic_15.png', 'pic_16.png', 'pic_17.png', 'pic_18.png', 'pic_19.png', 'pic_20.png', 'pic_21.png', 'pic_22.png', 'pic_23.png', 'pic_24.png', 'pic_25.png', 'pic_26.png', 'pic_27.png', 'pic_28.png', 'pic_29.png', 'pic_30.png', 'pic_31.png', 'pic_32.png', 'pic_33.png', 'pic_34.png', 'pic_35.png', 'pic_36.png', 'pic_37.png', 'pic_38.png', 'pic_39.png', 'pic_40.png', 'pic_41.png', 'pic_42.png', 'pic_43.png', 'pic_44.png', 'pic_45.png', 'pic_46.png', 'pic_47.png', 'pic_48.png', 'pic_49.png', 'pic_50.png', 'pic_51.png', 'pic_52.png', 'pic_53.png', 'pic_54.png', 'pic_55.png', 'pic_56.png', 'pic_57.png', 'pic_58.png', 'pic_59.png', 'pic_60.png', 'pic_61.png', 'pic_62.png', 'pic_63.png', 'pic_64.png', 'pic_65.png', 'pic_66.png', 'pic_67.png', 'pic_68.png', 'pic_

100%|██████████| 391/391 [01:48<00:00,  3.62it/s]


(50000, 2048)
(50000,)
(50000,)


In [6]:
# Persist the predicted labels to a scratch directory. The directory is
# created first because torch.save does not create missing parent folders.
labels_path = "./临时垃圾-随时可删/labels.pt"
os.makedirs(os.path.dirname(labels_path), exist_ok=True)
torch.save(labels, labels_path)

[1 9 5 2 8 2 7 3 7 3]
