In [1]:
import os
import sys
import time
import torch
import torch.nn
import argparse
from PIL import Image
from tensorboardX import SummaryWriter
import numpy as np
from validate import validate
from data import create_dataloader
from networks.trainer import Trainer
from options.train_options import TrainOptions
from options.test_options import TestOptions
from util import Logger


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import random


def seed_torch(seed=1029, disable_cudnn=True):
    """Seed every RNG this notebook touches and make cuDNN deterministic.

    Args:
        seed: value fed to Python's ``random``, NumPy's global RNG and torch
            (CPU generator plus every CUDA device).
        disable_cudnn: when True (the original behaviour) cuDNN is switched
            off entirely. ``cudnn.deterministic = True`` already restricts
            cuDNN to deterministic convolution algorithms, so passing False
            keeps that guarantee while recovering cuDNN's speed.

    Note:
        ``PYTHONHASHSEED`` is only read when an interpreter starts, so
        setting it here affects child processes (e.g. spawned DataLoader
        workers) but not ``hash()`` randomisation in the running kernel.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU (a superset of torch.cuda.manual_seed, which
    # only seeds the current device); a no-op queue on CPU-only machines.
    torch.cuda.manual_seed_all(seed)
    # benchmark=True would let cuDNN pick (possibly nondeterministic)
    # algorithms per input shape; deterministic=True forbids that.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = not disable_cudnn


In [4]:
def get_val_opt():
    """Return a freshly parsed options object set up for validation.

    A new ``TrainOptions`` parse is taken (printing the option table), its
    ``dataroot`` is pointed at the ``val_split`` sub-directory, training mode
    is switched off, the ``no_resize``/``no_crop`` flags are cleared and
    ``serial_batches`` is enabled (presumably: iterate without shuffling).
    """
    opts = TrainOptions().parse(print_options=True)
    overrides = {
        'dataroot': '{}/{}/'.format(opts.dataroot, opts.val_split),
        'isTrain': False,
        'no_resize': False,
        'no_crop': False,
        'serial_batches': True,
    }
    for attr_name, value in overrides.items():
        setattr(opts, attr_name, value)
    return opts

In [7]:
from bitmind.real_fake_dataset import RealFakeDataset
from bitmind.real_image_dataset import RealImageDataset
#from bitmind.random_image_generator import RandomImageGenerator
import torchvision.transforms as transforms
import torch

# Real images: the dalle-mini/open-images dataset from the Hugging Face hub,
# one RealImageDataset per split (built in train -> validation -> test order).
train_real_image_dataset, val_real_image_dataset, test_real_image_dataset = [
    RealImageDataset(huggingface_dataset_names=['dalle-mini/open-images'], splits=[split_name])
    for split_name in ('train', 'validation', 'test')
]

# Fake images: local image folders loaded through the same dataset class.
# NOTE(review): no `splits=` is passed here, so each folder is read with the
# class's default split (the captured output shows "Generating train split").
train_fake_image_dataset, val_fake_image_dataset, test_fake_image_dataset = [
    RealImageDataset(huggingface_dataset_names=[folder_spec])
    for folder_spec in (
        'imagefolder:../bitmind/data/images/train',
        'imagefolder:../bitmind/data/images/val',
        'imagefolder:../bitmind/data/images/test',
    )
]


Downloading data: 100%|██████████| 963/963 [00:00<00:00, 203502.36files/s]
Generating train split: 963 examples [00:00, 16774.36 examples/s]
Downloading data: 100%|██████████| 115/115 [00:00<00:00, 106902.70files/s]
Generating train split: 115 examples [00:00, 12124.70 examples/s]


In [9]:
# Per-channel (R, G, B) normalisation statistics, keyed by the pretraining
# recipe they come from. Kept as (mean, std) pairs so both numbers for a
# recipe sit side by side, then split into the MEAN / STD lookups.
# NOTE(review): not referenced by any active code in the visible cells.
_NORM_STATS = {
    "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    "clip": ([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]),
}
MEAN = {recipe: mean for recipe, (mean, _std) in _NORM_STATS.items()}
STD = {recipe: std for recipe, (_mean, std) in _NORM_STATS.items()}

def CenterCrop():
    """Build a transform that crops an image to its largest centred square.

    The returned callable reads the image's ``.size`` tuple (PIL-style;
    assumes PIL inputs), takes the shorter side and delegates to
    ``transforms.CenterCrop`` with that size, so only the longer side is
    trimmed, symmetrically.
    """
    def _crop_to_square(image):
        shortest_side = min(image.size)
        square_crop = transforms.CenterCrop(shortest_side)
        return square_crop(image)

    return _crop_to_square

# Preprocessing shared by every split: force RGB, crop the largest centred
# square, resize to the 224x224 network input and convert to a float tensor
# (ToTensor scales 8-bit pixels into [0, 1]).
# Inputs are assumed to be PIL images - CenterCrop() already relies on PIL's
# `.size` tuple.
transform = transforms.Compose([
    # Normalise the PIL mode first. Palette ('P'), grey+alpha ('LA'), RGBA and
    # CMYK images would otherwise become 1-, 2- or 4-channel tensors; the old
    # `t.expand(3, -1, -1)` fix-up only covered the 1-channel case, crashed on
    # 2 channels and turned palette indices into fake greyscale.
    transforms.Lambda(lambda img: img.convert('RGB')),
    CenterCrop(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # NOTE(review): there is no Normalize step, so the model sees raw [0, 1]
    # pixels. If the backbone expects ImageNet/CLIP statistics, add
    # transforms.Normalize(mean=MEAN['imagenet'], std=STD['imagenet']) here -
    # confirm against what the Trainer and `validate` expect.
])

# Combine each split's real and fake image sources through RealFakeDataset,
# all sharing the same `transform` pipeline (built in train -> val -> test order).
train_dataset, val_dataset, test_dataset = [
    RealFakeDataset(real_image_dataset=real_source, fake_image_dataset=fake_source, transforms=transform)
    for real_source, fake_source in (
        (train_real_image_dataset, train_fake_image_dataset),
        (val_real_image_dataset, val_fake_image_dataset),
        (test_real_image_dataset, test_fake_image_dataset),
    )
]


In [11]:
from argparse import ArgumentParser
opt = TrainOptions().parse()
# Seed before the Trainer is built (next cell) so weight init and the
# training loader's shuffling are reproducible.
seed_torch(100)

# Every run artefact (log file, TensorBoard event dirs) lives under one directory.
run_dir = os.path.join(opt.checkpoints_dir, opt.name)
# NOTE(review): return value is discarded - presumably Logger tees stdout into
# the log file as a side effect; confirm in util.Logger.
Logger(os.path.join(run_dir, 'log.log'))

# NOTE(review): neither val_opt nor Testopt is read by any later visible cell
# (validation goes through `val_loader`); both parses only print/configure.
val_opt = get_val_opt()
Testopt = TestOptions().parse(print_options=False)

train_writer, val_writer = (SummaryWriter(os.path.join(run_dir, sub_dir)) for sub_dir in ("train", "val"))

# NOTE(review): the training loop below never reads earlystop_epoch, so this
# override currently has no effect.
opt.earlystop_epoch = 5

----------------- Options ---------------
                     arch: res50                         
               batch_size: 64                            
                    beta1: 0.9                           
                blur_prob: 0                             
                 blur_sig: 0.5                           
          checkpoints_dir: ./checkpoints                 
                class_bal: False                         
                  classes:                               
           continue_train: False                         
                 cropSize: 224                           
                 data_aug: False                         
                 dataroot: ./dataset/                    
                delr_freq: 20                            
          earlystop_epoch: 15                            
                    epoch: latest                        
              epoch_count: 1                             
                  gpu_ids: 0  

In [12]:

# Build the Trainer from the parsed training options (after seed_torch(100), so
# initialisation is seeded). The training loop below relies on its `model`,
# `loss`, `lr` and `total_steps` attributes and on its `set_input`,
# `optimize_parameters`, `adjust_learning_rate`, `train`/`eval` and
# `save_networks` methods.
model = Trainer(opt)



In [13]:
from torch.utils.data import DataLoader

# Batches are left as plain tuples of dataset samples (no default_collate
# stacking) - `model.set_input` receives that tuple. `tuple` behaves exactly
# like `lambda d: tuple(d)` but is picklable, which matters if num_workers > 0.
# NOTE(review): this batch size differs from opt.batch_size (64 in the printed
# options); confirm which one is intended.
LOADER_BATCH_SIZE = 32
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=LOADER_BATCH_SIZE, shuffle=is_train, num_workers=0, collate_fn=tuple)
    for split_dataset, is_train in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)


In [None]:
# Train for opt.niter epochs, validate after each one, keep the checkpoint
# with the best validation accuracy, then evaluate on the test split.
# -inf (not 0) so the first epoch always writes a 'best' checkpoint.
best_val_acc = float('-inf')

model.train()
print(f'cwd: {os.getcwd()}')
for epoch in range(opt.niter):
    epoch_start_time = time.time()
    epoch_iter = 0  # training samples consumed so far in this epoch

    for data in train_loader:
        model.total_steps += 1
        # The loader's collate_fn yields a tuple of samples, so len(data) is the
        # real batch size (the loader's own, smaller on the last batch);
        # opt.batch_size does not describe this loader.
        epoch_iter += len(data)

        model.set_input(data)
        model.optimize_parameters()

        # Report at opt.loss_freq rather than every step so the notebook
        # output (and the saved .ipynb) stays small.
        if model.total_steps % opt.loss_freq == 0:
            print(time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()), "Train loss: {} at step: {} lr {}".format(model.loss, model.total_steps, model.lr))
            train_writer.add_scalar('loss', model.loss, model.total_steps)

    print("(Epoch {}) trained on {} samples in {:.1f}s".format(epoch, epoch_iter, time.time() - epoch_start_time))

    if epoch % opt.delr_freq == 0 and epoch != 0:
        print(time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()), 'changing lr at the end of epoch %d, iters %d' %
              (epoch, model.total_steps))
        model.adjust_learning_rate()

    # Validation
    model.eval()
    acc, ap = validate(model.model, val_loader)[:2]
    val_writer.add_scalar('accuracy', acc, model.total_steps)
    val_writer.add_scalar('ap', ap, model.total_steps)
    print("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))
    if acc > best_val_acc:
        model.save_networks('best')
        best_val_acc = acc
    # TODO(review): opt.earlystop_epoch is set earlier but no early stopping is
    # implemented here - every run trains for the full opt.niter epochs.
    model.train()

# Final evaluation on the held-out test split.
# NOTE(review): this scores the weights from the last epoch, not the 'best'
# checkpoint saved above - reload 'best' first if that is the intended metric.
model.eval()
acc, ap = validate(model.model, test_loader)[:2]
print("(Test) acc: {}; ap: {}".format(acc, ap))
model.save_networks('last')


