Notebook to sample images from PlaqueGAN and filter them, using the Plaquebox CNN as the selection classifier.

In [1]:
import torch
import torch.nn as nn

from torchvision import transforms
from PIL import Image
import torchvision.utils as vutils
import torchvision
from torch.cuda import amp
import numpy as np

from operation import load_params, get_config

import argparse
from tqdm import tqdm

import pandas as pd
import numpy as np

from models import Generator

from models_orig import Generator as GeneratorOld
from metrics.metric_utils import Batched_Normalize
import os

import json

In [2]:
# backbone of the CNN model to load the model parameters into
class Net(nn.Module):
    """Skeleton of the Plaquebox CNN so the pickled model object can be restored.

    The constructor only initialises ``nn.Module``; its arguments are accepted but
    unused. ``forward`` relies on ``self.features`` and ``self.classifier``, which
    are presumably restored from the pickled checkpoint rather than built here.
    """

    def __init__(self, fc_nodes=512, num_classes=3, dropout=0.5):
        super().__init__()

    def forward(self, x):
        """Extract features, flatten each sample to a vector, then classify."""
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flat)
    
# load plaquebox model
# is_available must be *called*: the bare function object is always truthy and
# would select "cuda" even on machines without a GPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pbox_ckpt_path = '../Plaquebox/plaquebox-paper-master/models/CNN_model_parameters.pkl'
norm_stats_path = '../Plaquebox/plaquebox-paper-master/utils/normalization.npy'

# Normalisation statistics for the CNN input: a dict with 'mean' and 'std' entries.
# Loaded once; the stats file only needs reading a single time.
norm = np.load(norm_stats_path, allow_pickle=True).item()

# SECURITY: both np.load(allow_pickle=True) and torch.load unpickle arbitrary
# objects - only use trusted checkpoint / stats files.
# The checkpoint is a whole pickled module (hence the `Net` skeleton class); it is
# mapped onto CPU storage here and moved to the device by the sampling code.
pbox_cnn = torch.load(pbox_ckpt_path, map_location=lambda storage, loc: storage)
# modules()[0] is the loaded object itself and [1] is its first sub-module -
# presumably the actual Net inside a wrapper such as DataParallel (TODO confirm).
pbox_cnn = list(pbox_cnn.modules())[1]

  and should_run_async(code)


In [3]:
def truncated_z_sample(batch_size, z_dim, truncation = 0.5, seed = None):
    """Draw a batch of latent vectors for the generator.

    Args:
        batch_size: number of latent vectors (rows) to draw.
        z_dim: dimensionality of each latent vector.
        truncation: if > 0, sample a standard normal truncated to [-2, 2] and
            scale it by ``truncation``; otherwise draw plain standard-normal values.
        seed: optional integer seed; when given, the draw is reproducible in
            both branches.

    Returns:
        float32 tensor of shape (batch_size, z_dim) on the CPU.
    """
    if truncation > 0:
        # Local import: the notebook's import cell never imports truncnorm, so this
        # branch used to fail with NameError.
        from scipy.stats import truncnorm
        state = None if seed is None else np.random.RandomState(seed)
        values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state)
        return torch.as_tensor(truncation * values, dtype=torch.float32)
    # Honour the seed here as well; previously it was silently ignored in this branch.
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    return torch.randn((batch_size, z_dim), generator=generator)

def sample_generator(netG, z, norm_stats=None):
    """Generate one batch and prepare it both for saving and for classification.

    The first output of ``netG(z)`` is taken as the image batch in [-1, 1].

    Returns:
        (to_save, classifier_input):
        - ``to_save`` is the raw batch rescaled to [0, 1] and moved to the CPU.
        - ``classifier_input`` is the batch quantised to 8 bits exactly as a saved
          image would be, rescaled to [0, 1], and normalised with
          ``norm_stats['mean']`` / ``norm_stats['std']`` when ``norm_stats`` is given.
    """
    raw = netG(z)[0]
    # [-1, 1] -> [0, 255] uint8, mirroring what writing the image to disk would do
    quantised = raw.mul(127.5).add(128).clamp(0, 255).to(torch.uint8)
    classifier_input = quantised.float().div(255)
    if norm_stats is not None:
        classifier_input = Batched_Normalize(classifier_input, norm_stats['mean'], norm_stats['std'])
    to_save = raw.add(1).mul(0.5).to('cpu')
    return to_save, classifier_input

  and should_run_async(code)


In [4]:
def _load_generator(base_dir, morph, ckpt_iter, ema, device):
    """Rebuild one morph's Generator from its run config and load checkpoint weights.

    Returns:
        (netG, noise_dim): the generator on ``device`` and its latent size ``nz``.
    """
    run_dir = os.path.join(base_dir, f'{morph}_final')
    ckpt_path = os.path.join(run_dir, 'models', f'all_{ckpt_iter}.pth')
    # map_location lets GPU-saved checkpoints load on CPU-only machines too.
    checkpoint = torch.load(ckpt_path, map_location=device)

    # The training run's args name the model-config row used to build the generator.
    with open(os.path.join(run_dir, 'args.txt'), mode='r') as f:
        args_train = json.load(f)
    model_config = get_config('model_configs.csv', args_train['model_config'], type='model')

    netG = Generator(
            nz                  = model_config['nz'],
            activation          = model_config['g_activation'],
            chan_attn           = model_config['g_chan_attn'],
            sle_map             = model_config['g_skip_map'],
            skip_conn           = model_config['g_skip_conn'],
            spatial_attn        = model_config['g_spatial_attn'],
            attn_layers         = model_config['g_attn_layers'],
            conv_layers         = model_config['g_conv_layers'],
            alternate_layers    = model_config['g_alternate_layers'],
            anti_alias          = model_config['g_anti_alias'],
            noise_inj           = model_config['g_noise_inj'],
            multi_res_out       = model_config['g_multi_res_out'],
            small_im_size       = model_config['g_small_im_size'],
            use_tanh            = model_config['use_tanh']
    )
    # Weights are taken from checkpoint['g_ema'] when ema is set, else checkpoint['g'].
    load_params(netG, checkpoint['g_ema'] if ema else checkpoint['g'])
    netG.to(device)
    print(f'loaded {morph} generator from {ckpt_path}')
    return netG, model_config['nz']


def _accepted_indices(preds, class_idx, not_class_idx, thresh_true_class, thresh_false_class):
    """Return a 1-D (possibly empty) tensor of rows whose scores pass both thresholds.

    A row is accepted when every class in ``class_idx`` scores above
    ``thresh_true_class`` and every class in ``not_class_idx`` scores below
    ``thresh_false_class``. Boolean masks keep the result 1-D, so iterating it is
    safe even when exactly one row is accepted.
    """
    accept = (preds[:, class_idx] > thresh_true_class).all(dim=1)
    if not_class_idx:
        accept &= (preds[:, not_class_idx] < thresh_false_class).all(dim=1)
    return accept.nonzero(as_tuple=True)[0]


@torch.no_grad()
def sample_gan_class_confidence(args_dict, classifier, norm, save_dir = '../Plaquebox/plaquebox-paper-master/data/tiles/train_and_val'):
    """Sample each PlaqueGAN generator, keeping only images the classifier labels confidently.

    For every (morph, checkpoint, target count, class indices) entry in ``args_dict``,
    the matching generator is loaded and sampled in batches. Each batch is scored with
    ``sigmoid(classifier(images))``. An image is saved to ``<save_dir>/gan/<n>.jpg``
    only if it scores above ``thresh_true_class`` for every class in ``class_idx`` and
    below ``thresh_false_class`` for every other class. Sampling for a morph stops once
    ``target_gen`` images have been accepted. File numbering ``n`` runs on across morphs.

    Args:
        args_dict: dict with per-morph parallel lists 'morph', 'ckpts_use',
            'target_gen' and 'class_idx', plus 'base_dir', 'batch_class',
            'batch_gen', 'thresh_true_class', 'thresh_false_class' and 'ema'.
        classifier: multi-label CNN; assumed to output one logit per class for the
            3 classes (TODO confirm against the Plaquebox model).
        norm: normalisation stats dict ('mean', 'std') applied to classifier inputs.
        save_dir: root directory; accepted images go into its 'gan' subfolder.

    Returns:
        (file_paths, labels_stacked): relative paths 'gan/<n>.jpg' and the matching
        multi-hot label lists, in the order the images were saved.
    """
    all_classes = [0, 1, 2]
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    base_dir = args_dict['base_dir']
    # acceptance threshold for desired class(es)
    thresh_true_class = args_dict['thresh_true_class']
    # rejection threshold for remaining classes - do not want these with too high confidence (but some confidence may be useful)
    thresh_false_class = args_dict['thresh_false_class']
    batch_class = args_dict['batch_class']   # images scored per classifier call
    batch_gen = args_dict['batch_gen']       # images produced per generator call
    ema = args_dict['ema']

    out_dir = os.path.join(save_dir, 'gan')
    # PIL's save() does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    # The classifier set-up is loop-invariant: move it and switch to eval mode once.
    classifier.to(device)
    classifier.eval()

    labels_stacked = []
    file_paths = []
    total_saved = 0
    per_morph = zip(args_dict['morph'], args_dict['ckpts_use'], args_dict['target_gen'], args_dict['class_idx'])
    for morph, ckpt_iter, target_gen, class_idx in per_morph:
        label = [1 if i in class_idx else 0 for i in all_classes]
        not_class_idx = [i for i in all_classes if i not in class_idx]

        netG, noise_dim = _load_generator(base_dir, morph, ckpt_iter, ema, device)

        # Generator warm-up: 100 throw-away forward passes before sampling.
        # NOTE(review): netG is never switched to eval() here, so these passes may
        # update normalisation-layer running statistics - confirm this is intended.
        for _ in range(100):
            z = truncated_z_sample(batch_gen, noise_dim, truncation=0).to(device)
            sample_generator(netG, z, norm_stats=norm)

        generated = 0
        while generated < target_gen:
            images_2_save = []
            images = []
            for _ in range(batch_class // batch_gen):
                z = truncated_z_sample(batch_gen, noise_dim, truncation=0).to(device)
                imgs_save, imgs = sample_generator(netG, z, norm_stats=norm)
                images_2_save.append(imgs_save)
                images.append(imgs)

            images = torch.cat(images)
            images_2_save = torch.cat(images_2_save)
            if images.shape[1] == 1:
                # single-channel output: replicate to the 3 channels fed to the classifier
                images = images.repeat([1, 3, 1, 1])

            # run through classifier
            with amp.autocast():
                preds = torch.sigmoid(classifier(images)).detach().cpu()

            idx_accept = _accepted_indices(preds, class_idx, not_class_idx,
                                           thresh_true_class, thresh_false_class)

            # now save the images that pass the threshold requirements
            for j in idx_accept.tolist():
                imarr = images_2_save[j].mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
                if imarr.shape[2] == 1:
                    # PIL cannot build an image from an (H, W, 1) array
                    imarr = imarr[:, :, 0]
                Image.fromarray(imarr).save(os.path.join(out_dir, f'{total_saved}.jpg'), quality=95)
                labels_stacked.append(label)
                file_paths.append(f'gan/{total_saved}.jpg')
                generated += 1
                total_saved += 1
                if generated % 1000 == 0:
                    print(f'number of {morph} generated: {generated}')
                if generated == target_gen:
                    break

    return file_paths, labels_stacked

  and should_run_async(code)


In [5]:
# Sampling configuration: the per-morph settings are parallel lists, one entry per morph.
# The experiment root can be overridden via the PLAQUEGAN_BASE_DIR environment variable.
base_dir = os.environ.get('PLAQUEGAN_BASE_DIR', 'D:/ucl_masters_data/project/fastGAN_experiments/extended')
morphs = ['cored','CAA','cored-diffuse','CAA-diffuse']
# Target classes per morph, indexing the label columns ('cored', 'diffuse', 'CAA').
class_idx = [[0], [2], [0,1], [1,2]]
ckpts_use = [90000, 100000, 80000, 70000]   # checkpoint iteration to load per morph
target_gen = [35728, 38955, 11198, 7644]    # accepted images wanted per morph
# zip() silently truncates mismatched lists, so make sure the per-morph lists line up.
assert len(morphs) == len(class_idx) == len(ckpts_use) == len(target_gen), \
    'per-morph config lists must all have the same length'

thresh_true_class = 0.7 # threshold to accept classes
thresh_false_class = 0.3 # threshold to reject other classes
args_dict = {'base_dir': base_dir, 'morph': morphs, 'ckpts_use': ckpts_use, 'target_gen': target_gen, 
             'class_idx': class_idx, 'batch_class': 32, 'batch_gen': 8, 'thresh_true_class': thresh_true_class, 
             'thresh_false_class': thresh_false_class, 'ema': True}

  and should_run_async(code)


In [7]:
# Run the filtered sampling for every morph in args_dict. This is slow: tens of
# thousands of accepted images are generated and written to disk.
file_paths, labels_stacked = sample_gan_class_confidence(args_dict=args_dict, classifier=pbox_cnn, norm=norm)

  and should_run_async(code)


all ok!
number of cored generated: 1000
number of cored generated: 2000
number of cored generated: 3000
number of cored generated: 4000
number of cored generated: 5000
number of cored generated: 6000
number of cored generated: 7000
number of cored generated: 8000
number of cored generated: 9000
number of cored generated: 10000
number of cored generated: 11000
number of cored generated: 12000
number of cored generated: 13000
number of cored generated: 14000
number of cored generated: 15000
number of cored generated: 16000
number of cored generated: 17000
number of cored generated: 18000
number of cored generated: 19000
number of cored generated: 20000
number of cored generated: 21000
number of cored generated: 22000
number of cored generated: 23000
number of cored generated: 24000
number of cored generated: 25000
number of cored generated: 26000
number of cored generated: 27000
number of cored generated: 28000
number of cored generated: 29000
number of cored generated: 30000
number of c

In [None]:
# Release cached GPU memory after sampling; guarded so the cell also runs on CPU-only
# machines, where torch.cuda.synchronize() would raise.
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

In [4]:
# Rebuild the image manifest (relative paths + multi-hot labels) without re-sampling.
# Assumes a previous sampling run saved exactly target_gen images for each morph, in
# args_dict order, numbered consecutively from 0.
# Loop variables use distinct names so the config cell's `class_idx` / `target_gen`
# lists are not overwritten in the kernel namespace.
labels_stacked = []
file_paths = []
total_saved = 0
for morph_name, n_target, morph_class_idx in zip(args_dict['morph'], args_dict['target_gen'], args_dict['class_idx']):
    morph_label = [1 if i in morph_class_idx else 0 for i in range(3)]
    n_images = max(n_target, 0)
    labels_stacked.extend([morph_label] * n_images)
    file_paths.extend(f'gan/{k}.jpg' for k in range(total_saved, total_saved + n_images))
    total_saved += n_images
        

  and should_run_async(code)


In [None]:
# One row per generated tile: its relative path followed by the multi-hot class labels.
label_columns = ['cored', 'diffuse', 'CAA']
df_out = pd.DataFrame(labels_stacked, columns=label_columns)
df_out.insert(0, 'imagename', file_paths)
df_out.head()

In [None]:
# Write the manifest of GAN tiles (presumably consumed by Plaquebox training).
manifest_path = './train_gann_up_3.csv'
df_out.to_csv(manifest_path, index=False)