In [None]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image


from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch

import sys
sys.path.append('../')
import torch
from torchvision import utils
import matplotlib.pyplot as plt
from src.pl_module import MelanomaModel
from src.models.networks import Generator_auxGAN_512
from src.transforms.albu import get_valid_transforms_with_resize
import albumentations as A
import numpy as np
import torch.nn as nn
from tqdm.auto import tqdm
from catalyst.utils import set_global_seed
import skimage.io
import pandas as pd
from torch.autograd import Variable
import cv2
%matplotlib inline

In [None]:
# True when a CUDA device is visible. Informational only: later cells call `.cuda()`
# unconditionally, so this notebook needs a GPU regardless of this flag.
cuda = torch.cuda.is_available()

In [None]:
# Generator / classifier configuration (n_classes was previously assigned twice).
n_classes = 2              # number of conditioning classes for the generator
latent_dim = 100           # length of the latent vector z fed to the generator
img_size = 512             # generator output resolution
model_img_size = 384       # input resolution expected by the classifier ensemble
channels = 3               # image channels (RGB)
n_samples_per_class = 10   # images per class in the preview grid

In [None]:
# Class-conditional 512px generator (called later as `generator(z, labels)`), placed on the GPU.
# `nn.Module.cuda()` returns the module itself, so chaining it keeps this cell from
# echoing the very long module repr as its output.
generator = Generator_auxGAN_512().cuda()

In [None]:
# Deserialize the generator checkpoint onto the CPU, then copy it into the GPU-resident
# parameters. Without map_location, torch.load restores tensors to the device they were
# saved from, which fails if that GPU is not present. The cell shows the key-match report.
generator.load_state_dict(torch.load('../GANs_weights/generator_512_68000.pth', map_location='cpu'))

In [None]:
def load_model(model_name: str, model_type: str, weights: str):
    """Build an ensemble member via ``MelanomaModel.net_mapping`` and load its weights.

    Args:
        model_name: backbone name passed to ``MelanomaModel.net_mapping``.
        model_type: head type passed to ``MelanomaModel.net_mapping``.
        weights: path to a plain state dict (``.pth``) or to a checkpoint
            (``.ckpt``) whose ``"state_dict"`` keys carry a ``net.`` prefix.

    Returns:
        The model switched to eval mode and moved to the GPU.

    Raises:
        ValueError: if ``weights`` ends with neither ``.pth`` nor ``.ckpt``.
    """
    model = MelanomaModel.net_mapping(model_name, model_type)
    if weights.endswith('.pth'):
        # Deserialize on the CPU (like the .ckpt branch); model.cuda() below moves it.
        model.load_state_dict(torch.load(weights, map_location='cpu'))
    elif weights.endswith('.ckpt'):
        checkpoint = torch.load(weights, map_location='cpu')
        prefix = 'net.'
        model_keys = model.state_dict().keys()
        # Keep only entries stored under the `net.` attribute (prefix removed) that name a
        # parameter/buffer of this model; any other checkpoint entries are ignored.
        pretrained_dict = {
            key[len(prefix):]: value
            for key, value in checkpoint["state_dict"].items()
            if key.startswith(prefix) and key[len(prefix):] in model_keys
        }
        # Strict load: raises if any model parameter is absent from the checkpoint, so an
        # ensemble member is never silently left partially initialised.
        model.load_state_dict(pretrained_dict)
    else:
        # Without this, an unexpected suffix would return a randomly initialised model.
        raise ValueError(f"Unsupported weights file (expected .pth or .ckpt): {weights}")
    model.eval()
    model.cuda()
    print("Loaded model {} from checkpoint {}".format(model_name, weights))
    return model

In [None]:
# Classifier ensemble used to score generated images. The three lists are parallel:
# entry i of each describes one model. Commented-out entries are excluded models and
# must be commented out in BOTH model_name_list and weights_list.
model_name_list = [
    'resnest50d', 
    'resnest269e', 
    'resnest101e', 
    #'seresnext101_32x4d', 
    'tf_efficientnet_b3_ns', 
    'tf_efficientnet_b7_ns', 
    'tf_efficientnet_b5_ns']
model_type_list = ['SingleHeadMax'] * len(model_name_list)
weights_list = [
    '../weights/train_384_balancedW_resnest50d_fold0_heavyaugs_averaged_best_weights.pth',
    '../weights/07.09_train_384_balancedW_resnest269e_heavyaugs_averaged_best_weights.pth',
    '../weights/03.09_train_384_balancedW_resnest101e_fold0_heavyaugs_averaged_best_weights.pth',
    #'../weights/06.18_train_384_balancedW_seresnext101_32x4d_fold0_heavyaugs_averaged_best_weights.pth',
    '../weights/06.10_train_384_balancedW_b3_fold0_heavyaugs_averaged_best_weights.pth',
    '../weights/05.23_train_384_balancedW_b7_fold0_heavyaugs_averaged_best_weights.pth',
    '../weights/03.18_train_384_balancedW_b5_fold0_heavyaugs_averaged_best_weights.pth'
]
# zip() silently truncates to the shortest list, which would drop or mis-pair ensemble
# members if the lists drift apart; fail loudly instead.
assert len(model_name_list) == len(model_type_list) == len(weights_list), (
    'model_name_list, model_type_list and weights_list must have the same length')
models = [load_model(model_name, model_type, weights) for model_name, model_type, weights in
          zip(model_name_list, model_type_list, weights_list)]
# Resize + normalisation applied to every generated image before it is scored.
valid_norm = get_valid_transforms_with_resize(model_img_size)

In [None]:
# Preview: generate n_samples_per_class images per class and score them with the
# classifier ensemble, for visual inspection in the next cell.
z = torch.from_numpy(
    np.random.normal(0, 1, (n_classes * n_samples_per_class, latent_dim))
).float().cuda()
# Interleaved labels [0, 1, ..., n_classes-1, 0, 1, ...] with shape (N, 1).
labels = torch.from_numpy(
    np.tile(np.arange(n_classes), n_samples_per_class)[:, None]
).long().cuda()
with torch.no_grad():  # inference only; avoid building an autograd graph
    generated_images = generator(z, labels)
# mul(0.5).add(0.5) maps [-1, 1] to [0, 1] (assumes a tanh-style generator output -- confirm),
# then scale to [0, 255] and convert NCHW tensor -> NHWC numpy for albumentations.
generated_images = (255 * generated_images.mul(0.5).add(0.5)).float()
generated_images = generated_images.cpu().numpy().transpose(0, 2, 3, 1)
normalized_generated_images = np.stack(
    [valid_norm(image=image)['image'] for image in generated_images]
)
normalized_generated_images = torch.from_numpy(normalized_generated_images.transpose(0, 3, 1, 2))
# Transfer the batch to the GPU once rather than once per ensemble member.
model_input = normalized_generated_images.cuda()
with torch.no_grad():
    preds = torch.stack([torch.sigmoid(model(model_input)) for model in models])
# Mean over the ensemble axis; column 0 is read as the class-1 probability below.
cls_1_pred = preds.mean(dim=0).cpu().numpy()

In [None]:
# Preview grid: each panel shows the conditioning label and the ensemble's class-1 score.
# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so .ravel() always works.
fig, axes = plt.subplots(n_classes, n_samples_per_class, figsize=(19, 5), squeeze=False)
panel_data = zip(axes.ravel(), generated_images, labels[:, 0].cpu().numpy(), cls_1_pred[:, 0])
for axis, image, label, score in panel_data:
    axis.imshow(image.astype(int))
    axis.set_title(f'Generated: {label}\npredicted: {score:.2f}')
    axis.set_yticklabels([])
    axis.set_xticklabels([])

In [None]:
# Bulk-generation settings.
batch_size = 32                 # images generated and scored per round
total_generate_images = 200000  # target number of images to write to disk
# Ceil division: floor division would silently produce fewer than
# total_generate_images images whenever the total is not a multiple of batch_size.
n_rounds_generation = (total_generate_images + batch_size - 1) // batch_size

In [None]:
# Bulk generation: produce n_rounds_generation class-balanced batches, score each image
# with the classifier ensemble and save it as a JPEG. The three lists stay index-aligned
# (one entry per saved image).
image_names_list = []      # full path of every saved JPEG
image_class_list = []      # ensemble-mean sigmoid score (column 0 of the model output)
generated_class_list = []  # class label the generator was conditioned on
output_folder = '/data/personal_folders/skolchenko/kaggle_melanoma/generated_data_v.004/'
output_img_size = 384      # side length, in pixels, of the saved JPEGs
os.makedirs(output_folder, exist_ok=True)

# Equal number of samples per class in every batch, derived from n_classes so that the
# latents, labels, predictions and saved files stay aligned for any number of classes.
samples_per_class = batch_size // n_classes
effective_batch_size = samples_per_class * n_classes
assert samples_per_class > 0, 'batch_size must be at least n_classes'
# Interleaved labels [0, 1, ..., n_classes-1, 0, 1, ...] of shape (B, 1); identical every round.
labels = torch.from_numpy(
    np.tile(np.arange(n_classes), samples_per_class)[:, None]
).long().cuda()
label_values = labels[:, 0].cpu().numpy()

set_global_seed(42)
for genround in tqdm(range(n_rounds_generation)):
    z = torch.from_numpy(
        np.random.normal(0, 1, (effective_batch_size, latent_dim))
    ).float().cuda()
    with torch.no_grad():  # inference only; avoid building an autograd graph
        generated_images = generator(z, labels)
    # mul(0.5).add(0.5) maps [-1, 1] to [0, 1] (assumes a tanh-style generator output -- confirm);
    # scale to [0, 255] and convert NCHW tensor -> NHWC numpy.
    generated_images = (255 * generated_images.mul(0.5).add(0.5)).float()
    generated_images = generated_images.cpu().numpy().transpose(0, 2, 3, 1)

    normalized_generated_images = np.stack(
        [valid_norm(image=image)['image'] for image in generated_images]
    )
    # Transfer the batch to the GPU once rather than once per ensemble member.
    normalized_generated_images = torch.from_numpy(
        normalized_generated_images.transpose(0, 3, 1, 2)
    ).cuda()
    with torch.no_grad():
        preds = torch.stack([torch.sigmoid(model(normalized_generated_images)) for model in models])
    cls_1_pred = preds.mean(dim=0)[:, 0].cpu().numpy()

    first_index = genround * effective_batch_size
    image_names = [output_folder + f'generated_{first_index + x}.jpg'
                   for x in range(effective_batch_size)]
    for image, image_name in zip(generated_images, image_names):
        resized_image = cv2.resize(image, (output_img_size, output_img_size))
        # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
        skimage.io.imsave(fname=image_name, arr=np.clip(resized_image, 0, 255).astype(np.uint8))

    image_class_list.extend(cls_1_pred)
    image_names_list.extend(image_names)
    generated_class_list.extend(label_values)

assert len(image_names_list) == len(image_class_list) == len(generated_class_list)

In [None]:
# One row per generated image: file stem, ensemble score, and conditioning label.
generated_data_csv = pd.DataFrame({
    # File stem without directory or extension, e.g. 'generated_123'. splitext only strips
    # the final extension, unlike split('.')[0], which would truncate stems containing dots.
    'image_name': [os.path.splitext(os.path.basename(path))[0] for path in image_names_list],
    'target': image_class_list,
    'generated_target': generated_class_list})
generated_data_csv.head()

In [None]:
def get_mask_target(row, thr=0.5):
    """Decide whether a generated sample's ensemble score agrees with its conditioning label.

    Args:
        row: mapping with a ``'target'`` score and a ``'generated_target'`` label.
        thr: decision threshold applied to the score.

    Returns:
        True for a class-0 sample scored strictly below ``thr`` or a class-1 sample
        scored strictly above ``thr``. A score exactly equal to ``thr`` returns False,
        as does any other ``generated_target`` value.
    """
    score = row['target']
    generated_as = row['generated_target']
    if generated_as == 0:
        return bool(score < thr)
    if generated_as == 1:
        return bool(score > thr)
    return False
# Keep only samples whose ensemble score agrees with the class they were generated for.
mask_selection = generated_data_csv.apply(get_mask_target, axis=1)
generated_data_csv_cleaned = generated_data_csv[mask_selection]
n_total, n_kept = len(generated_data_csv), len(generated_data_csv_cleaned)
print(f'Generated {n_total} samples, but gonna use only {n_kept}')
generated_data_csv_cleaned.head()

In [None]:
# Ensemble-score distribution per conditioning label, over all generated samples.
# A shared bin grid over [0, 1] (sigmoid outputs) makes the overlaid histograms comparable;
# independent auto-binning would give each histogram different bin edges.
bins = np.linspace(0, 1, 21)
fig, ax = plt.subplots()
for generated_target in (1, 0):
    scores = generated_data_csv.loc[generated_data_csv['generated_target'] == generated_target, 'target']
    ax.hist(scores, bins=bins, alpha=0.4, label=f'generated_target={generated_target}')
ax.set(xlabel='ensemble score (target)', ylabel='count', title='All generated samples')
ax.legend();

In [None]:
# Ensemble-score distribution per conditioning label, after filtering.
# A shared bin grid over [0, 1] (sigmoid outputs) makes the overlaid histograms comparable.
bins = np.linspace(0, 1, 21)
fig, ax = plt.subplots()
for generated_target in (1, 0):
    scores = generated_data_csv_cleaned.loc[
        generated_data_csv_cleaned['generated_target'] == generated_target, 'target']
    ax.hist(scores, bins=bins, alpha=0.4, label=f'generated_target={generated_target}')
ax.set(xlabel='ensemble score (target)', ylabel='count', title='Filtered generated samples')
ax.legend();

In [None]:
# Persist the filtered subset first, then the full table of generated samples.
csv_exports = (
    (generated_data_csv_cleaned, '../data/generated_data_v.004.cleaned.csv'),
    (generated_data_csv, '../data/generated_data_v.004.csv'),
)
for frame, csv_path in csv_exports:
    frame.to_csv(csv_path, index=False)