In [1]:
import os
import sys
import time
import torch
import torch.nn
import argparse
from PIL import Image
from tensorboardX import SummaryWriter
import numpy as np
from validate import validate
from data import create_dataloader
from networks.trainer import Trainer
from options.train_options import TrainOptions
from options.test_options import TestOptions
from util import Logger


In [2]:
import random


def seed_torch(seed=1029):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN settings.

    Args:
        seed: Value applied to every generator (default 1029).

    Notes:
        * ``PYTHONHASHSEED`` is written to the environment. It only takes effect
          for Python processes started afterwards. The running interpreter's
          hash seed is fixed at startup.
        * cuDNN is switched off entirely (``enabled = False``), not merely made
          deterministic. This trades GPU speed for reproducibility.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # CPU generator, current CUDA device, then every CUDA device (multi-GPU).
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False


In [3]:
# test config: generator families evaluated by testmodel(), one sub-folder each
# under the test root. multiclass[i] == 1 means the entries of vals[i]'s folder
# are listed and used as classes; 0 means the folder is read as a single class.
vals = ['progan', 'stylegan', 'stylegan2', 'biggan', 'cyclegan', 'stargan', 'gaugan', 'deepfake']
multiclass = [1, 1, 1, 0, 1, 0, 0, 0]
# testmodel() indexes these lists in parallel; fail fast instead of silently
# skipping or mis-pairing entries when one list is edited without the other.
assert len(vals) == len(multiclass), 'vals and multiclass must have the same length'

len(vals), len(multiclass)

(8, 8)

In [4]:
def get_val_opt():
    """Return a fresh set of training options re-targeted at the validation split.

    The options are parsed (and printed) anew. ``dataroot`` is rewritten to
    ``<dataroot>/<val_split>/``. ``isTrain``, ``no_resize`` and ``no_crop`` are
    cleared and ``serial_batches`` is set (presumably in-order batches; confirm
    in the options module).
    """
    opts = TrainOptions().parse(print_options=True)
    opts.dataroot = '{}/{}/'.format(opts.dataroot, opts.val_split)
    overrides = (
        ('isTrain', False),
        ('no_resize', False),
        ('no_crop', False),
        ('serial_batches', True),
    )
    for attr_name, attr_value in overrides:
        setattr(opts, attr_name, attr_value)
    return opts

In [5]:

def testmodel():
    print('*'*25);accs = [];aps = []
    print(time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()))
    for v_id, val in enumerate(vals):
        Testopt.dataroot = '{}/{}'.format(Testdataroot, val)
        Testopt.classes = os.listdir(Testopt.dataroot) if multiclass[v_id] else ['']
        Testopt.no_resize = False
        Testopt.no_crop = True
        acc, ap, _, _, _, _ = validate(model.model, Testopt)
        accs.append(acc);aps.append(ap)
        print("({} {:10}) acc: {:.1f}; ap: {:.1f}".format(v_id, val, acc*100, ap*100))
    print("({} {:10}) acc: {:.1f}; ap: {:.1f}".format(v_id+1,'Mean', np.array(accs).mean()*100, np.array(aps).mean()*100));print('*'*25) 
    print(time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()))
    
# model.eval();testmodel();

In [6]:
from bitmind.real_fake_dataset import RealFakeDataset
from bitmind.real_image_dataset import RealImageDataset
#from bitmind.random_image_generator import RandomImageGenerator
import torchvision.transforms as transforms
import torch

#real_image_dataset = RealImageDataset(huggingface_datasets=['dalle-mini/open-images'])
# Real images: the dalle-mini/open-images dataset (loaded from the local HF cache,
# per the output below). Fake images: one local image folder per split.
real_image_dataset = RealImageDataset(huggingface_dataset_names=['dalle-mini/open-images'])
train_fake_image_dataset, val_fake_image_dataset, test_fake_image_dataset = [
    RealImageDataset(huggingface_dataset_names=[fake_spec])
    for fake_spec in (
        'imagefolder:../bitmind/data/images/train',
        'imagefolder:../bitmind/data/images/val',
        'imagefolder:../bitmind/data/images/test',
    )
]


Using the latest cached version of the dataset since dalle-mini/open-images couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /Users/duys/.cache/huggingface/datasets/dalle-mini___open-images/default/0.0.0/242c4f02f66851c6a98c3866f8b0f541226dda4f (last modified on Sun May  5 17:40:52 2024).


Resolving data files:   0%|          | 0/884 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/538 [00:00<?, ?it/s]

In [7]:
# Per-channel normalisation statistics keyed by backbone pretraining source.
# Unused in the visible cells: the transform pipeline applies no Normalize.
MEAN = dict(
    imagenet=[0.485, 0.456, 0.406],
    clip=[0.48145466, 0.4578275, 0.40821073],
)

STD = dict(
    imagenet=[0.229, 0.224, 0.225],
    clip=[0.26862954, 0.26130258, 0.27577711],
)

def CenterCrop():
    """Build a transform that center-crops an image to a square of its shorter side.

    The returned callable reads ``image.size`` as a (width, height) pair, so it
    presumably expects a PIL image (confirm against the dataset's output).
    """
    def _crop_to_square(image):
        shortest_side = min(image.size)
        cropper = transforms.CenterCrop(shortest_side)
        return cropper(image)
    return _crop_to_square

# Preprocessing shared by every split: force RGB, crop to the largest centered
# square, resize to 224x224, convert to a float tensor.
transform = transforms.Compose([
    # Convert first so grayscale ('L'), palette ('P'), RGBA, LA, CMYK, ... all
    # become 3-channel RGB. ToTensor alone yields 4/2-channel tensors for
    # RGBA/LA and raw palette indices for 'P', which breaks batching and
    # corrupts pixel values. For 'L' this matches replicating the gray channel.
    transforms.Lambda(lambda img: img.convert('RGB')),
    CenterCrop(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # NOTE(review): no Normalize step is applied, so inputs stay in [0, 1].
    # Confirm this matches what the backbone and the validate() data path expect.
    # transforms.Normalize(mean=MEAN['imagenet'], std=STD['imagenet']),
])

train_dataset = RealFakeDataset(real_image_dataset=real_image_dataset, fake_image_dataset=train_fake_image_dataset, transforms=transform)
val_dataset = RealFakeDataset(real_image_dataset=real_image_dataset, fake_image_dataset=val_fake_image_dataset, transforms=transform)
test_dataset = RealFakeDataset(real_image_dataset=real_image_dataset, fake_image_dataset=test_fake_image_dataset, transforms=transform)


In [8]:
from argparse import ArgumentParser
# Parse training options and seed every RNG before the trainer is built.
opt = TrainOptions().parse()
seed_torch(100)
print(opt)
# Root of the per-generator test folders. testmodel() reads this global, and the
# training loop calls testmodel() after every epoch, so it must be defined here.
Testdataroot = os.path.join(opt.dataroot, 'test')
#print(Testdataroot)

----------------- Options ---------------
                     arch: res50                         
               batch_size: 64                            
                    beta1: 0.9                           
                blur_prob: 0                             
                 blur_sig: 0.5                           
          checkpoints_dir: ./checkpoints                 
                class_bal: False                         
                  classes:                               
           continue_train: False                         
                 cropSize: 224                           
                 data_aug: False                         
                 dataroot: ./dataset/                    
                delr_freq: 20                            
          earlystop_epoch: 15                            
                    epoch: latest                        
              epoch_count: 1                             
                  gpu_ids: -1 

In [9]:

# Per-run output directory: holds the log file and both TensorBoard event dirs.
run_dir = os.path.join(opt.checkpoints_dir, opt.name)
# Logger's return value is discarded. Presumably its constructor installs the
# log file as a side effect; confirm in util.Logger.
Logger(os.path.join(run_dir, 'log.log'))
print('  '.join(sys.argv))

# Options for validation (val split) and for the per-generator sweep in testmodel().
val_opt = get_val_opt()
Testopt = TestOptions().parse(print_options=False)

train_writer, val_writer = (
    SummaryWriter(os.path.join(run_dir, sub_dir)) for sub_dir in ("train", "val")
)


/Users/duys/anaconda3/envs/bmsn/lib/python3.10/site-packages/ipykernel_launcher.py  -f  /Users/duys/Library/Jupyter/runtime/kernel-09e4af40-319f-4975-8b59-77a57ae5d256.json
----------------- Options ---------------
                     arch: res50                         
               batch_size: 64                            
                    beta1: 0.9                           
                blur_prob: 0                             
                 blur_sig: 0.5                           
          checkpoints_dir: ./checkpoints                 
                class_bal: False                         
                  classes:                               
           continue_train: False                         
                 cropSize: 224                           
                 data_aug: False                         
                 dataroot: ./dataset/                    
                delr_freq: 20                            
          earlystop_epoch: 15  

In [10]:

# Build the trainer from the parsed training options. From its use in the
# training loop, it exposes the network as `model.model`, train()/eval() mode
# switches, set_input + optimize_parameters for one step, loss / lr /
# total_steps for logging, adjust_learning_rate() and save_networks().
model = Trainer(opt)



In [11]:
from torch.utils.data import DataLoader

# One loader per split. `collate_fn=tuple` hands each batch over as a plain tuple
# of dataset items (no tensor stacking). It behaves the same as the former
# `lambda d: tuple(d)` but is picklable, should num_workers ever be raised.
# NOTE(review): confirm Trainer.set_input expects a tuple of per-sample items
# rather than default-collated (images, labels) tensors.
# NOTE(review): batch_size here (32) differs from opt.batch_size (64 in the
# printed options); decide which one is authoritative.
loader_kwargs = dict(batch_size=32, num_workers=0, collate_fn=tuple)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
# Built from test_dataset (not val_dataset) so test data is actually held out.
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


In [12]:

# Training loop: one pass over train_loader per epoch, periodic LR decay, then
# validation and the full per-generator test sweep after every epoch.
model.train()
print(f'cwd: {os.getcwd()}')
for epoch in range(opt.niter):
    epoch_start_time = time.time()

    for data in train_loader:
        model.total_steps += 1
        model.set_input(data)
        model.optimize_parameters()

        # Log every `loss_freq` steps, to stdout and TensorBoard, instead of
        # dumping the raw batch and a loss line on every single step.
        if model.total_steps % opt.loss_freq == 0:
            print(time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()), "Train loss: {} at step: {} lr {}".format(model.loss, model.total_steps, model.lr))
            train_writer.add_scalar('loss', model.loss, model.total_steps)

    # Decay the learning rate every `delr_freq` epochs (never after epoch 0).
    if epoch % opt.delr_freq == 0 and epoch != 0:
        print(time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()), 'changing lr at the end of epoch %d, iters %d' %
              (epoch, model.total_steps))
        model.adjust_learning_rate()

    # Validation
    model.eval()
    # NOTE(review): testmodel() passes an options object to validate(), but here
    # a DataLoader is passed, and val_opt (from get_val_opt) is otherwise unused
    # in the visible cells. Confirm which argument type validate() accepts.
    acc, ap = validate(model.model, val_loader)[:2]
    val_writer.add_scalar('accuracy', acc, model.total_steps)
    val_writer.add_scalar('ap', ap, model.total_steps)
    print("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))
    testmodel()
    model.train()
    print('epoch {} took {:.1f}s'.format(epoch, time.time() - epoch_start_time))

model.eval()
# The last epoch already ran testmodel() on these exact weights. Only sweep here
# when no epoch ran at all.
if opt.niter <= 0:
    testmodel()
model.save_networks('last')
