In [4]:
import os
from tqdm import tqdm

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images

def find_classes(directory):
    """Finds the class folders in a dataset.
    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

In [5]:
import cv2
import random
def create_new_datasets(path, new_sample_path, new_real_path, size):
    classes, class_to_idx = find_classes(path)
    for class_name in tqdm(classes):
        srcPath = path + class_name
        #print(srcPath)
        if len(os.listdir(srcPath)) < size:
            S_sample = os.listdir(srcPath)
        else:
            S_sample = random.sample(os.listdir(srcPath), k=size)
        S_real = set(os.listdir(srcPath)) - set(S_sample)
        
        for image in tqdm(S_sample):
            img = cv2.imread(srcPath + '/' + image)
            #print(img)
            dirs = new_sample_path + class_name
            if not os.path.exists(dirs):
                os.makedirs(dirs)
            if img is not None and is_image_file(image):
                cv2.imwrite(dirs + '/' + image, img)
                '''
        
        for image in tqdm(S_real):
            img = cv2.imread(srcPath + '/' + image)
            #print(img)
            dirs = new_real_path + class_name
            if not os.path.exists(dirs):
                os.makedirs(dirs)
            if img is not None  and is_image_file(image):
                cv2.imwrite(dirs + '/' + image, img)
                '''
                

In [6]:
def copy_dataset(path, new_path):
    # copy all images in this dictionary to new_path
    images = make_dataset(path)
    print(len(images))
    for image in tqdm(images):
        img = cv2.imread(image)
        #print(image)
        image_name = image.split('/')[-2] + '_' + image.split('/')[-1]
        if not os.path.exists(new_path):
            os.makedirs(new_path)
        if img is not None:
            cv2.imwrite(new_path + '/' + image_name, img)
        
        

In [7]:
import numpy as np
import random
from tqdm import tqdm
import os
import shutil
def split_dataset(path, new_path, size):
    #classes, class_to_idx = find_classes(path+'train/')
    #test_classes = random.sample(classes, k=size)
    classes, class_to_idx = find_classes(path+'/train_and_valid/')
    class_dic = {}
    for one_class in classes:
        class_path = path + '/train_and_valid/' + one_class
        images = make_dataset(class_path)
        class_dic[one_class] = len(images)

    #print(class_dic)
    sortedDict = {k: v for k, v in sorted(class_dic.items(), key=lambda item: -item[1])}
    #print(sortedDict)

    res = sum(sortedDict.values()) / len(sortedDict)
    # printing result
    print("Average number of images per class : " + str(res))

    
    test_classes = random.sample(classes, k=size)
    images_per_class = []
    for test_class in test_classes:
        images_per_class.append(class_dic[test_class])
    res = sum(images_per_class)/len(images_per_class)
    print("Average number of images per class of selected class: " + str(res))
    train_classes = set(classes)-set(test_classes)
    for test_class in tqdm(test_classes):
        source_dir = path + '/valid/' + str(test_class)
        destination_dir = new_path + '/test/' + str(test_class)
        shutil.copytree(source_dir, destination_dir)
    
        source_dir = path + '/train/' + str(test_class)
        destination_dir = new_path + '/test/' + str(test_class)
        images = make_dataset(source_dir)
        for image in images:
            img = cv2.imread(image)
            #print(image)
            image_name = image.split('/')[-2] + '_' + image.split('/')[-1]
            if not os.path.exists(new_path):
                os.makedirs(new_path)
            if img is not None:
                cv2.imwrite(destination_dir + '/' + image_name, img)
    print("Finish creating test set!")
    for train_class in tqdm(train_classes):
        source_dir = path + '/valid/' + str(train_class)
        destination_dir = new_path + '/train/' + str(train_class)
        shutil.copytree(source_dir, destination_dir)

        source_dir = path + '/train/' + str(train_class)
        destination_dir = new_path + '/train/' + str(train_class)
        images = make_dataset(source_dir)
        for image in images:
            img = cv2.imread(image)
            #print(image)
            image_name = image.split('/')[-2] + '_' + image.split('/')[-1]
            if not os.path.exists(new_path):
                os.makedirs(new_path)
            if img is not None:
                cv2.imwrite(destination_dir + '/' + image_name, img)
    print("Finish creating train set!")
    

In [8]:
def select_subset(path, new_path, size):
    classes, class_to_idx = find_classes(path)
    test_classes = random.sample(classes, k=size)
    train_classes = set(classes)-set(test_classes)
    print(len(test_classes))
    print(len(train_classes))
    for i in range(size):
        test_class = classes[i]
        source_dir = path + '/' + str(test_class)
        destination_dir = new_path + '/' + str(test_class)
        shutil.copytree(source_dir, destination_dir)
        print('finish {}/{}'.format(i, size))
    '''
    for test_class in tqdm(test_classes):
        source_dir = path + 'valid/' + str(test_class)
        destination_dir = new_path + 'test/' + str(test_class)
        shutil.copytree(source_dir, destination_dir)
        '''
    print("Finish creating new dataset!")


In [9]:
import random

def select_best_test_set(path, new_path, size, classes, sortedDict):
    '''
    classes, class_to_idx = find_classes(path)
    class_dic = {}
    for one_class in classes:
        class_path = path + '/' + one_class
        images = make_dataset(class_path)
        class_dic[one_class] = len(images)

    #print(class_dic)
    sortedDict = {k: v for k, v in sorted(class_dic.items(), key=lambda item: -item[1])}
    #print(sortedDict)

    res = sum(sortedDict.values()) / len(sortedDict)
    # printing result
    #print("Average number of images per class : " + str(res))
    '''
    '''
    count = 0
    test_classes = []
    for key in sortedDict:
        count += 1
        test_classes.append(key)
        if count >= size:
            break
    print(test_classes)
    #print(sortedDict)
    #print(sortedDict[:size])
    '''
    
    test_classes = random.sample(classes, k=size)
    images_per_class = []
    for test_class in test_classes:
        images_per_class.append(class_dic[test_class])
    res = sum(images_per_class)/len(images_per_class)
    #print("Average number of images per class of selected class: " + str(res))
    train_classes = set(classes)-set(test_classes)
    #print(len(test_classes))
    #print(len(train_classes))
        
    '''
    for test_class in tqdm(test_classes):
        source_dir = path + '/' + str(test_class)
        destination_dir = new_path + '/test/' + str(test_class)
        shutil.copytree(source_dir, destination_dir)
        
    print("Finish creating test set!")

    for train_class in tqdm(train_classes):
        source_dir = path + '/' + str(train_class)
        destination_dir = new_path + '/train/' + str(train_class)
        shutil.copytree(source_dir, destination_dir)
        
    print("Finish creating train set!")
    '''
    return res

In [10]:
import random

def select_best_test_set_animals(path, new_path, size):
    
    classes, class_to_idx = find_classes(path)
    class_dic = {}
    for one_class in classes:
        class_path = path + '/' + one_class
        images = make_dataset(class_path)
        class_dic[one_class] = len(images)

    #print(class_dic)
    sortedDict = {k: v for k, v in sorted(class_dic.items(), key=lambda item: -item[1])}
    #print(sortedDict)

    res = sum(sortedDict.values()) / len(sortedDict)
    # printing result
    print("Average number of images per class : " + str(res))

    
    test_classes = random.sample(classes, k=size)
    images_per_class = []
    for test_class in test_classes:
        images_per_class.append(class_dic[test_class])
    res = sum(images_per_class)/len(images_per_class)
    print("Average number of images per class of selected class: " + str(res))
    train_classes = set(classes)-set(test_classes)
    #print(len(test_classes))
    #print(len(train_classes))
        
    
    for test_class in tqdm(test_classes):
        source_dir = path + '/' + str(test_class)
        destination_dir = new_path + '/test/' + str(test_class)
        shutil.copytree(source_dir, destination_dir)
        
    print("Finish creating test set!")

    for train_class in tqdm(train_classes):
        source_dir = path + '/' + str(train_class)
        destination_dir = new_path + '/train/' + str(train_class)
        shutil.copytree(source_dir, destination_dir)
        
    print("Finish creating train set!")
    
    return res

In [11]:
import random

def select_best_test_set_faces(path, new_path, size):
    
    classes, class_to_idx = find_classes(path)
    class_dic = {}
    for one_class in classes:
        class_path = path + '/' + one_class
        images = make_dataset(class_path)
        class_dic[one_class] = len(images)

    #print(class_dic)
    sortedDict = {k: v for k, v in sorted(class_dic.items(), key=lambda item: -item[1])}
    #print(sortedDict)

    res = sum(sortedDict.values()) / len(sortedDict)
    # printing result
    print("Average number of images per class : " + str(res))

    count = 0
    selected_classes = []
    for key in sortedDict:
        count += 1
        selected_classes.append(key)
        if count >= size+10:
            break
    #print(test_classes)
    #print(sortedDict)
    #print(sortedDict[:size])

    
    test_classes = random.sample(selected_classes, k=size)
    images_per_class = []
    for test_class in test_classes:
        images_per_class.append(class_dic[test_class])
    res = sum(images_per_class)/len(images_per_class)
    print("Average number of images per class of selected class: " + str(res))
    train_classes = set(classes)-set(test_classes)
    #print(len(test_classes))
    #print(len(train_classes))
        
    
    for test_class in tqdm(test_classes):
        source_dir = path + '/' + str(test_class)
        destination_dir = new_path + '/test/' + str(test_class)
        shutil.copytree(source_dir, destination_dir)
        
    print("Finish creating test set!")

    for train_class in tqdm(train_classes):
        source_dir = path + '/' + str(train_class)
        destination_dir = new_path + '/train/' + str(train_class)
        shutil.copytree(source_dir, destination_dir)
        
    print("Finish creating train set!")
    
    return res

In [12]:
def sample_dataset(path, new_path, size):
    classes, _ = find_classes(path)
    for class_name in tqdm(classes):
        class_path = os.path.join(path, class_name)
        images = make_dataset(class_path)
        sample_images = random.sample(images, k=size)
        # print(sample_images)
        
        for image in sample_images:
            img = cv2.imread(image)
            #print(img)
            dirs = os.path.join(new_path, class_name)
            if not os.path.exists(dirs):
                os.makedirs(dirs)
            if img is not None:
                cv2.imwrite(dirs + '/' + image.split('/')[-1], img)
            

In [13]:
def sample_dataset_new(path, new_path, size):
    images = make_dataset(path)
    sample_images = random.sample(images, k=size)
    print(sample_images)
    
    for image in tqdm(sample_images):
        img = cv2.imread(image)
        #print(img)
        dirs = new_path
        if not os.path.exists(dirs):
            os.makedirs(dirs)
        if img is not None:
            cv2.imwrite(dirs + '/' + image.split('/')[-1], img)

In [8]:
sample_dataset_new('/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid_new', 2000)

['/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid/Charles_Dance_00000435.jpg', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid/Julie_Chen_00000450.jpg', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid/Miranda_Hart_00000129.jpg', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid/Garrett_Hedlund_00000279.jpg', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid/Michelle_Keegan_00000329.jpg', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid/Jeff_Bezos_00000097.jpg', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid/Jason_Dohring_00000382.jpg', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid/Esther_Williams_00000518.jpg', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid/Virginia_Madsen_00000524.jpg', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid/Robert_Pattinson_00000199.jpg', '/hom

100%|██████████| 2000/2000 [00:12<00:00, 165.84it/s]


In [11]:
sample_dataset('/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion/outputs/vggfaces_eva_genrated_cfg_1.3_strength_0.95_r_5.4_30_samples_6_each',
               '/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion/outputs/vggfaces_eva_genrated_cfg_1.3_strength_0.95_r_5.4_30_samples_6_each_128/', 128)

  0%|          | 0/572 [00:00<?, ?it/s]

100%|██████████| 572/572 [05:49<00:00,  1.64it/s]


In [15]:
classes_1, class_to_idx_1 = find_classes('/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion/outputs/vggfaces_eva_genrated_cfg_1.3_strength_0.95_r_5.4_30_samples_6_each_128/')
classes_2, class_to_idx_2 = find_classes('/data2/mhf/DXL/Lingxiao/datasets/vggfaces_eva/test')
print(len(classes_1))
print(len(classes_2))
for i in range(len(classes_2)):
    if classes_2[i] not in classes_1[i]:
        print(i)
        print(classes_2[i])

522
572
400
NeNe_Leakes
401
Neal_McDonough
402
Neil_deGrasse_Tyson
403
Nelly
404
Nestor_Carbonell
405
Niall_Horan
406
Nick_Cannon
407
Nick_Wechsler
408
Nicki_Minaj
409
Nicola_Peltz
410
Nikki_Blonsky
411
Nikki_Griffin
412
Nikolaj_Coster-Waldau
413
Nina_Arianda
414
Nina_Dobrev
415
Noa_Tishby
416
Noah_Emmerich
417
Noel_Clarke
418
Nora_Arnezeder
419
Oliver_Platt
420
Olivia_Hallinan
421
Olly_Alexander
422
Omar_Epps
423
Omar_Gooding
424
Orlando_Jones
425
Paget_Brewster
426
Patrick_Dempsey
427
Patrick_Macnee
428
Patrick_Wilson
429
Paul_Adelstein
430
Paul_Michael_Glaser
431
Paul_Reubens
432
Paul_Scheer
433
Paula_LaBaredas
434
Penn_Badgley
435
Pete_Postlethwaite
436
Pete_Wentz
437
Peter_Facinelli
438
Peter_Krause
439
Pierre_Niney
440
Qi_Shu
441
Quentin_Tarantino
442
Rachel_Dratch
443
Rade_Serbedzija
444
Rafael_Nadal
445
Rainn_Wilson
446
Raoul_Bova
447
Rebecca_Creskoff
448
Rekha
449
Reshma_Shetty
450
Retta
451
Rhys_Coiro
452
Richard_Boone
453
Richard_Grieco
454
Richard_Madden
455
Richard_Schiff


IndexError: list index out of range

In [20]:
classes, _ = find_classes(
    '/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion/outputs/vggfaces_eva_genrated_cfg_1.3_strength_0.95_r_5.4_30_samples_6_each')
for class_name in classes:
    class_path = os.path.join(
        '/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion/outputs/vggfaces_eva_genrated_cfg_1.3_strength_0.95_r_5.4_30_samples_6_each', class_name)
    images = make_dataset(class_path)
    print(len(images))

180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180


330
300
276
342
306
336
336
336
324
288
318
312
342
324
318
354
336
324
324
342
318
312
288
300
318
330
324
336
324
330
300
324
336
336
318
336
342
324
342
354
348
324
342
342
324
318
336
312
324
318
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180
180


In [160]:
path = '/home/zhangyuanyuan/Lingxiao/datasets/flowers/dataset/train'
classes, class_to_idx = find_classes(path)
class_dic = {}
for one_class in classes:
    class_path = path + '/' + one_class
    images = make_dataset(class_path)
    class_dic[one_class] = len(images)

#print(class_dic)
sortedDict = {k: v for k, v in sorted(class_dic.items(), key=lambda item: -item[1])}
print(sortedDict)

res = sum(sortedDict.values()) / len(sortedDict)
print("Average number of images per class : " + str(res))
for i in range(100000000):
    random.seed(i)
    average_number = select_best_test_set('/home/zhangyuanyuan/Lingxiao/datasets/flowers/dataset/train', '/home/zhangyuanyuan/Lingxiao/datasets/flowers/dataset/train', 17, classes, sortedDict)
    if average_number > 106:
        print(i)
        print(average_number)

{'51': 206, '77': 205, '46': 157, '89': 153, '73': 147, '74': 142, '81': 135, '94': 132, '88': 116, '78': 112, '83': 104, '95': 101, '43': 100, '41': 97, '75': 95, '37': 92, '56': 92, '65': 88, '58': 86, '60': 85, '76': 83, '80': 82, '82': 82, '72': 77, '12': 73, '44': 73, '50': 73, '23': 72, '96': 72, '53': 70, '8': 70, '11': 68, '98': 68, '52': 67, '84': 66, '90': 66, '18': 65, '71': 64, '29': 62, '36': 62, '30': 61, '47': 61, '17': 60, '91': 59, '48': 57, '55': 56, '59': 56, '28': 55, '40': 54, '5': 54, '97': 54, '92': 53, '66': 51, '70': 51, '87': 51, '57': 50, '99': 50, '101': 49, '2': 49, '42': 49, '31': 48, '62': 48, '85': 48, '86': 48, '22': 47, '54': 47, '20': 46, '69': 46, '14': 44, '38': 44, '4': 44, '68': 43, '63': 42, '64': 42, '9': 41, '10': 38, '13': 38, '15': 38, '19': 38, '49': 38, '102': 36, '16': 36, '27': 36, '3': 36, '32': 36, '61': 36, '67': 36, '100': 35, '24': 35, '6': 35, '21': 34, '25': 34, '79': 34, '93': 34, '26': 33, '35': 33, '39': 33, '45': 33, '7': 33, '

KeyboardInterrupt: 

In [13]:
# save seed to split flowers dataset
# random.seed(6947447)#seeds 6947447 5439445 12162761
split_dataset('/data2/mhf/DXL/Lingxiao/datasets/flowers/dataset',
              '/data2/mhf/DXL/Lingxiao/datasets/flowers_eva_random', 17)

Average number of images per class : 72.25490196078431
Average number of images per class of selected class: 70.88235294117646


  0%|          | 0/17 [00:00<?, ?it/s]

100%|██████████| 17/17 [00:06<00:00,  2.43it/s]


Finish creating test set!


100%|██████████| 85/85 [00:35<00:00,  2.39it/s]

Finish creating train set!





In [19]:
# save seed to split vggfaces dataset
random.seed(795437)#seeds 795437 438987 1533772
average_number = select_best_test_set_animals(
    '/data2/mhf/DXL/Lingxiao/datasets/vggfaces', '/data2/mhf/DXL/Lingxiao/datasets/vggfaces_eva', 572)
print(average_number)

Average number of images per class : 155.37358326068002
Average number of images per class of selected class: 160.97202797202797


100%|██████████| 572/572 [00:11<00:00, 48.10it/s]


Finish creating test set!


100%|██████████| 1722/1722 [00:32<00:00, 52.20it/s]

Finish creating train set!
160.97202797202797





In [32]:
# save seed to split vggfaces dataset
random.seed(795437)#seeds 795437 438987 1533772
average_number = select_best_test_set_faces('/home/zhangyuanyuan/Lingxiao/datasets/vggfaces', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected', 572)
print(average_number)

Average number of images per class : 155.37358326068002
Average number of images per class of selected class: 185.53321678321677


100%|██████████| 572/572 [00:11<00:00, 50.63it/s]


Finish creating test set!


100%|██████████| 1722/1722 [00:31<00:00, 55.09it/s]

Finish creating train set!
185.53321678321677





In [20]:
# save seed to split animalfaces dataset
random.seed(2357561)#seed 2357561 9539210 5482551
average_number = select_best_test_set_animals(
    '/data2/mhf/DXL/Lingxiao/datasets/animals', '/data2/mhf/DXL/Lingxiao/datasets/animals_eva', 30)
print(average_number)

Average number of images per class : 788.48322147651
Average number of images per class of selected class: 893.8666666666667


100%|██████████| 30/30 [00:01<00:00, 16.78it/s]


Finish creating test set!


100%|██████████| 119/119 [00:05<00:00, 21.40it/s]

Finish creating train set!
893.8666666666667





In [23]:
# save seed to split nabirds dataset
random.seed(2357561)  # seed 2357561 9539210 5482551
average_number = select_best_test_set_faces(
    '/data2/mhf/DXL/Lingxiao/datasets/nabirds/images', '/data2/mhf/DXL/Lingxiao/datasets/nabirds_eva', 111)
print(average_number)

Average number of images per class : 87.4990990990991
Average number of images per class of selected class: 119.36036036036036


100%|██████████| 111/111 [00:09<00:00, 11.83it/s]


Finish creating test set!


100%|██████████| 444/444 [00:23<00:00, 18.92it/s]

Finish creating train set!
119.36036036036036





In [7]:
select_subset('/home/zhangyuanyuan/Lingxiao/datasets/animals', '/home/zhangyuanyuan/Lingxiao/datasets/animals_60', 60)

60
89
finish 0/60
finish 1/60
finish 2/60
finish 3/60
finish 4/60
finish 5/60
finish 6/60
finish 7/60
finish 8/60
finish 9/60
finish 10/60
finish 11/60
finish 12/60
finish 13/60
finish 14/60
finish 15/60
finish 16/60
finish 17/60
finish 18/60
finish 19/60
finish 20/60
finish 21/60
finish 22/60
finish 23/60
finish 24/60
finish 25/60
finish 26/60
finish 27/60
finish 28/60
finish 29/60
finish 30/60
finish 31/60
finish 32/60
finish 33/60
finish 34/60
finish 35/60
finish 36/60
finish 37/60
finish 38/60
finish 39/60
finish 40/60
finish 41/60
finish 42/60
finish 43/60
finish 44/60
finish 45/60
finish 46/60
finish 47/60
finish 48/60
finish 49/60
finish 50/60
finish 51/60
finish 52/60
finish 53/60
finish 54/60
finish 55/60
finish 56/60
finish 57/60
finish 58/60
finish 59/60
Finish creating new dataset!


In [52]:
split_dataset('/home/zhangyuanyuan/Lingxiao/datasets/flowers/dataset', '/home/zhangyuanyuan/Lingxiao/datasets/flowers_eva', 17)

{'1': 27, '10': 38, '100': 35, '101': 49, '102': 36, '11': 68, '12': 73, '13': 38, '14': 44, '15': 38, '16': 36, '17': 60, '18': 65, '19': 38, '2': 49, '20': 46, '21': 34, '22': 47, '23': 72, '24': 35, '25': 34, '26': 33, '27': 36, '28': 55, '29': 62, '3': 36, '30': 61, '31': 48, '32': 36, '33': 31, '34': 28, '35': 33, '36': 62, '37': 92, '38': 44, '39': 33, '4': 44, '40': 54, '41': 97, '42': 49, '43': 100, '44': 73, '45': 33, '46': 157, '47': 61, '48': 57, '49': 38, '5': 54, '50': 73, '51': 206, '52': 67, '53': 70, '54': 47, '55': 56, '56': 92, '57': 50, '58': 86, '59': 56, '6': 35, '60': 85, '61': 36, '62': 48, '63': 42, '64': 42, '65': 88, '66': 51, '67': 36, '68': 43, '69': 46, '7': 33, '70': 51, '71': 64, '72': 77, '73': 147, '74': 142, '75': 95, '76': 83, '77': 205, '78': 112, '79': 34, '8': 70, '80': 82, '81': 135, '82': 82, '83': 104, '84': 66, '85': 48, '86': 48, '87': 51, '88': 116, '89': 153, '9': 41, '90': 66, '91': 59, '92': 53, '93': 34, '94': 132, '95': 101, '96': 72, '9

100%|██████████| 17/17 [00:17<00:00,  1.05s/it]


Finish creating test set!


100%|██████████| 85/85 [00:35<00:00,  2.41it/s]

Finish creating train set!





In [33]:
# Use this function to split test set to S_sample and S_real
create_new_datasets('/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected/test/', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_128/', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_128_real/', 128)

100%|██████████| 128/128 [00:00<00:00, 3292.19it/s]
100%|██████████| 128/128 [00:00<00:00, 3001.44it/s]
100%|██████████| 128/128 [00:00<00:00, 2515.10it/s]
100%|██████████| 128/128 [00:00<00:00, 1651.67it/s]
100%|██████████| 128/128 [00:00<00:00, 1797.39it/s]
100%|██████████| 128/128 [00:00<00:00, 2478.96it/s]
100%|██████████| 128/128 [00:00<00:00, 3083.04it/s]
100%|██████████| 128/128 [00:00<00:00, 2851.34it/s]
100%|██████████| 128/128 [00:00<00:00, 2735.40it/s]
100%|██████████| 128/128 [00:00<00:00, 2645.43it/s]
100%|██████████| 128/128 [00:00<00:00, 2834.99it/s]
100%|██████████| 128/128 [00:00<00:00, 1380.48it/s]
100%|██████████| 128/128 [00:00<00:00, 3102.55it/s]
100%|██████████| 128/128 [00:00<00:00, 2307.91it/s]
100%|██████████| 128/128 [00:00<00:00, 3179.87it/s]
100%|██████████| 128/128 [00:00<00:00, 2963.85it/s]
100%|██████████| 128/128 [00:00<00:00, 2871.03it/s]
100%|██████████| 128/128 [00:00<00:00, 2124.39it/s]
100%|██████████| 128/128 [00:00<00:00, 2647.87it/s]
100%|███████

In [34]:
# Use this function to change format of dataset to calculate socres
copy_dataset('/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected/test', '/home/zhangyuanyuan/Lingxiao/datasets/vggfaces_eva_selected_test_fid')

106125


 29%|██▉       | 30774/106125 [00:08<00:19, 3891.85it/s]libpng error: IDAT: invalid distance too far back
100%|██████████| 106125/106125 [00:28<00:00, 3757.73it/s]


In [15]:
import os
import torch
import shutil
import random
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                         0.229, 0.224, 0.225])  # 标准化
])


def split_data(data_dir, train_size, val_size, test_size):
    classes, _ = find_classes(data_dir)
    train_images = []
    val_images = []
    test_images = []

    for class_name in tqdm(classes):
        class_path = os.path.join(data_dir, class_name)
        images = make_dataset(class_path)
        random.shuffle(images)

        train_images = images[:train_size]
        val_images = images[train_size:train_size+val_size]
        test_images = images[train_size +
                              val_size:train_size+val_size+test_size]
        
        for image in train_images:
            img = cv2.imread(image)
            #print(img)
            dirs = os.path.join(data_dir+'_30', 'train', class_name)
            if not os.path.exists(dirs):
                os.makedirs(dirs)
            if img is not None:
                cv2.imwrite(dirs + '/' + image.split('/')[-1], img)
        for image in val_images:
            img = cv2.imread(image)
            #print(img)
            dirs = os.path.join(data_dir+'_30', 'valid', class_name)
            if not os.path.exists(dirs):
                os.makedirs(dirs)
            if img is not None:
                cv2.imwrite(dirs + '/' + image.split('/')[-1], img)
        for image in test_images:
            img = cv2.imread(image)
            #print(img)
            dirs = os.path.join(data_dir+'_30', 'test', class_name)
            if not os.path.exists(dirs):
                os.makedirs(dirs)
            if img is not None:
                cv2.imwrite(dirs + '/' + image.split('/')[-1], img)

    return train_images, val_images, test_images


class CustomDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = sorted(make_dataset(image_paths))
        self.classes, self.class_to_idx = find_classes(image_paths)
        self.transform = transform
        # self.folder_to_label = {folder.split('_')[0]: idx for idx, folder in enumerate(sorted(os.listdir(image_paths)))}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        class_label = image_path.split('/')[-2]
        label = self.class_to_idx[class_label]
        return image, label


def train_model(model, criterion, optimizer, train_loader, val_loader, output_dir, device, num_epochs=10):
    best_val_acc = 0.0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = correct / total * 100
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = correct / total * 100

        print(
            f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")
        print(f"Validation Accuracy: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            torch.save(model.state_dict(), os.path.join(
                output_dir, 'best_model.pth'))
            print("Best model saved!")

    return model


def test_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_acc = correct / total * 100
    print(f"Test Accuracy: {test_acc:.2f}%")


def load_best_model(model, model_path, device):
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)
    return model

def merge_dataset(path_1, path_2, new_path):
    classes_1, class_to_idx_1 = find_classes(path_1)
    classes_2, class_to_idx_2 = find_classes(path_2)
    for class_name in tqdm(classes_1):
        class_path = os.path.join(path_1, class_name)
        images = make_dataset(class_path)
        for image in images:
            img = cv2.imread(image)
            #print(img)
            dirs = os.path.join(new_path, class_name)
            if not os.path.exists(dirs):
                os.makedirs(dirs)
            if img is not None:
                cv2.imwrite(dirs + '/' + image.split('/')[-1], img)
                
    for class_name in tqdm(classes_2):
        class_path = os.path.join(path_2, class_name)
        images = make_dataset(class_path)
        for image in images:
            img = cv2.imread(image)
            #print(img)
            dirs = os.path.join(new_path, class_name)
            if not os.path.exists(dirs):
                os.makedirs(dirs)
            if img is not None:
                cv2.imwrite(dirs + '/' + image.split('/')[-1], img)

In [2]:
data_dir = '/data2/mhf/DXL/Lingxiao/datasets/animals_eva/test'
output_dir = '/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion/resnet/animal_faces/output'
pretrained_model_path = '/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion/ptmodel_bs128.pth'

train_size = 30
val_size = 35
test_size = 35
batch_size = 32
learning_rate = 0.001

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda:1


In [16]:
train_images, val_images, test_images = split_data(
    data_dir, train_size, val_size, test_size)

100%|██████████| 30/30 [00:03<00:00,  9.97it/s]


In [22]:
merge_dataset('/data2/mhf/DXL/Lingxiao/Codes/HypDiffusion/outputs/animals_eva_genrated_cfg_1.3_strength_1.0_r_4.5_10_samples_6_each',
              '/data2/mhf/DXL/Lingxiao/datasets/animals_eva/test_30/train', '/data2/mhf/DXL/Lingxiao/datasets/animals_eva/test_30/train_4.5')

100%|██████████| 30/30 [00:08<00:00,  3.36it/s]
100%|██████████| 30/30 [00:00<00:00, 42.63it/s]


In [27]:
train_data_path = '/data2/mhf/DXL/Lingxiao/datasets/animals_eva/test_30/train_5.7'
val_data_path = '/data2/mhf/DXL/Lingxiao/datasets/animals_eva/test_30/valid'
test_data_path = '/data2/mhf/DXL/Lingxiao/datasets/animals_eva/test_30/test'

train_dataset = CustomDataset(train_data_path, transform)
val_dataset = CustomDataset(val_data_path, transform)
test_dataset = CustomDataset(test_data_path, transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 119)

model.load_state_dict(torch.load(pretrained_model_path))
model.fc = nn.Linear(model.fc.in_features, 30)
nn.init.kaiming_normal_(model.fc.weight)
nn.init.zeros_(model.fc.bias)

model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)


model = train_model(model, criterion, optimizer,
                    train_loader, val_loader, output_dir, device, num_epochs=20)

model = load_best_model(model, os.path.join(
    output_dir, 'best_model.pth'), device)
test_model(model, test_loader, device)

Epoch 1/20: 100%|██████████| 85/85 [00:10<00:00,  7.83it/s]


Train Loss: 3.0619, Train Accuracy: 18.52%
Validation Accuracy: 61.14%
Best model saved!


Epoch 2/20: 100%|██████████| 85/85 [00:10<00:00,  8.02it/s]


Train Loss: 1.5120, Train Accuracy: 61.37%
Validation Accuracy: 58.95%


Epoch 3/20: 100%|██████████| 85/85 [00:09<00:00,  8.74it/s]


Train Loss: 0.4331, Train Accuracy: 95.48%
Validation Accuracy: 58.67%


Epoch 4/20: 100%|██████████| 85/85 [00:09<00:00,  9.10it/s]


Train Loss: 0.0713, Train Accuracy: 99.96%
Validation Accuracy: 58.29%


Epoch 5/20: 100%|██████████| 85/85 [00:09<00:00,  8.99it/s]


Train Loss: 0.0230, Train Accuracy: 100.00%
Validation Accuracy: 60.10%


Epoch 6/20:  39%|███▉      | 33/85 [00:03<00:05,  8.95it/s]


KeyboardInterrupt: 