In [None]:
import os
import sys

# Must run before the project-local `src` imports below.
sys.path.append('../')

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io
import torch
import torch.nn as nn
from catalyst.utils import set_global_seed
from torchvision import utils
from tqdm.auto import tqdm

from src.GANs import WGAN_GP
from src.pl_module import MelanomaModel
%matplotlib inline

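This is an example of using an unconditional GAN: images are sampled from a pretrained WGAN-GP generator, scored by an ensemble of melanoma classifiers, and saved together with their predicted scores.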

In [None]:
# Instantiate the unconditional WGAN-GP wrapper. Only its generator (`model.G`)
# is used below: it receives pretrained weights in a later cell and is then
# sampled to produce synthetic images.
model = WGAN_GP()

In [None]:
def load_model(model_name: str, model_type: str, weights: str, device: str = "cuda"):
    """Build a classifier via ``MelanomaModel.net_mapping`` and restore its weights.

    Args:
        model_name: backbone name passed to ``MelanomaModel.net_mapping``.
        model_type: head type passed to ``MelanomaModel.net_mapping``.
        weights: path to a checkpoint holding the model's ``state_dict``.
        device: device the restored model is moved to (default ``"cuda"``,
            matching the previous hard-coded ``.cuda()``).

    Returns:
        The model in eval mode, placed on ``device``.
    """
    model = MelanomaModel.net_mapping(model_name, model_type)
    # Deserialize onto the CPU first: a checkpoint saved on a GPU/device index that
    # is not available here would otherwise fail to load, and the temporary state
    # dict does not occupy GPU memory. The model is moved to `device` afterwards.
    model.load_state_dict(
        torch.load(weights, map_location="cpu")
    )
    model.eval()
    model.to(device)
    print("Loaded model {} from checkpoint {}".format(model_name, weights))
    return model

def get_valid_transforms():
    """Build the inference-time preprocessing pipeline.

    The pipeline consists solely of albumentations' ``Normalize`` with its
    default arguments, applied unconditionally (``p=1.0``).
    """
    pipeline = [A.Normalize()]
    return A.Compose(pipeline, p=1.0)
# Classifier ensemble: the same architecture trained on each CV fold; fold k's
# weights live in `../weights/resnest26d_128x128_fold{k}.pth`. The fold count is a
# single constant so the name/type/weight lists cannot drift out of sync.
n_folds = 5
model_name_list = ['resnest26d'] * n_folds
model_type_list = ['SingleHeadMax'] * n_folds
weights_list = [f'../weights/resnest26d_128x128_fold{x}.pth' for x in range(n_folds)]
models = [
    load_model(model_name, model_type, weights)
    for model_name, model_type, weights in zip(model_name_list, model_type_list, weights_list)
]
# Preprocessing applied to generated images before they are fed to the ensemble.
valid_norm = get_valid_transforms()

In [None]:
# Restore the pretrained generator weights into the GAN's generator sub-module.
model.G.load_state_dict(state_dict=torch.load(f='../GANs_weights/generator_2550.pkl'))

In [None]:
# Generation budget. Only full batches are generated, so any remainder of
# total_generate_images / batch_size is dropped: 500_000 images at 256 per batch
# gives 1953 rounds, i.e. 499_968 images.
batch_size = 256
total_generate_images = 500_000
n_rounds_generation = divmod(total_generate_images, batch_size)[0]

In [None]:
# Sample images from the generator batch by batch, save each as a JPEG and score it
# with the classifier ensemble. Results accumulate in `image_names_list` (file paths)
# and `image_class_list` (ensemble-mean sigmoid output for column 1).
image_names_list = []
image_class_list = []
output_folder = '/data/personal_folders/skolchenko/kaggle_melanoma/generated_data_v.001/'
# imsave does not create missing directories; ensure the target exists up front
# rather than failing after the first (expensive) generator pass.
os.makedirs(output_folder, exist_ok=True)
latent_dim = 100  # length of the noise vector fed to the generator
set_global_seed(42)
for genround in tqdm(range(n_rounds_generation)):
    # Pure inference: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        z = model.get_torch_variable(torch.randn(batch_size, latent_dim, 1, 1))
        generated = model.G(z)
        # The 0.5 scale/shift assumes generator output in [-1, 1]; map it to [0, 255].
        # Clamping guarantees the uint8 cast below can never wrap around.
        generated = (255 * generated.mul(0.5).add(0.5).clamp(0.0, 1.0)).int()
    # NCHW tensor -> NHWC uint8 array (layout expected by albumentations / skimage).
    generated_images = generated.cpu().numpy().transpose(0, 2, 3, 1).astype(np.uint8)

    normalized = np.stack([valid_norm(image=image)['image'] for image in generated_images])
    # Back to NCHW; move to the GPU once instead of once per ensemble member.
    normalized = torch.from_numpy(normalized.transpose(0, 3, 1, 2)).cuda()
    with torch.no_grad():
        preds = torch.stack([torch.sigmoid(clf(normalized)) for clf in models])
    cls_1_pred = preds.mean(dim=0)[:, 1].cpu().numpy()

    first_index = genround * batch_size
    image_names = [os.path.join(output_folder, f'generated_{first_index + x}.jpg')
                   for x in range(batch_size)]
    for fname, image in zip(image_names, generated_images):
        skimage.io.imsave(fname=fname, arr=image)
    image_class_list.extend(cls_1_pred)
    image_names_list.extend(image_names)

Now build a DataFrame listing each generated image together with its ensemble-predicted score.

In [None]:
# One row per generated image:
#   image_name - file stem of the saved JPEG (e.g. 'generated_0' for '.../generated_0.jpg')
#   target     - ensemble-mean sigmoid output for column 1
# The saved file names already start with 'generated_', so no extra prefix is added
# (prepending one yielded 'generated_generated_0', which matches no file on disk).
# NOTE(review): if a downstream loader compensated by stripping a 'generated_'
# prefix, update it accordingly.
generated_data_csv = pd.DataFrame({
    'image_name': [os.path.splitext(os.path.basename(x))[0] for x in image_names_list],
    'target': image_class_list})
generated_data_csv.head()

In [None]:
# Persist the generated-image labels (without the pandas index column).
generated_data_csv.to_csv(path_or_buf='../data/generated_data_v.001.csv', index=False)