In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms as torchvision_transforms
import numpy as np
import os
import pandas as pd
import math
import ntpath

from tqdm.notebook import tqdm as tqdm
from tqdm.auto import trange

import warnings
warnings.filterwarnings("ignore")

# from HGNN.train import CNN, dataLoader
from HGNN.train.configParser import ConfigParser, getModelName, getDatasetName
from myhelpers import config_plots
from HGNN.train import CNN, dataLoader

# config_plots.global_settings()

# CSV index of trained trials (looked up by its `trialHash` column), stored under experimentsPath.
# NOTE: the misspelled variable name is kept on purpose -- later cells reference it.
experimetnsFileName = 'experiments.csv'

In [None]:
# ---- Trained model to analyse ----
experimentsPath = "/raid/elhamod/Fish/experiments/"
dataPath = "/raid/elhamod/Fish"
# Alternative experiment: "Fish50_30-5run-BB-HGNN-crossvalidation"
experimentName = "Fish30-5run-PhyloNN6"
# PhyloNN trial. HGNN alternative: "9d6646b1d44b3034255f21a9d658fffe2f80e4f2180745e169abeb72"
trial_hash = "0e8bc6eb6edfb88c5a419e14ab0b445d72ee1945bd474a26a7abcbd4"

# ---- Reference image whose predicted class becomes the optimization target ----
# Alternative: '/raid/elhamod/Fish/Curated4/Easy_50/test/Notropis nubilus/INHS_FISH_81913.jpg'
fileName = '/raid/elhamod/Fish/Curated4/Easy_30/test/Carassius auratus/INHS_FISH_4916.jpg'

# ---- Misc ----
# GPU index handed to torch.cuda.set_device and the CNN helpers; None skips set_device
# (TODO confirm the CNN helpers also accept device=None).
cuda = 1
# Forwarded to datasetManager.getLoaders; the meaning of -1 is not visible here -- TODO confirm.
SEED_INT = -1

In [None]:
project = "Fish_activationMaximization_hyperp2"

# Objective that `train` reports through wandb.log; the Bayesian search minimizes it.
metric = dict(name='loss', goal='minimize')

# Maximum number of trials the agent runs.
COUNT = 1000

# Search space. The `log_uniform` bounds are natural logs of the intended
# range (1e-5 .. 10), hence math.log.
params = dict(
    useRegularization={'values': [True, False]},
    useRandomInitImage={'values': [True, False]},
    L1_reg={'values': [True, False]},
    iterations=dict(distribution='q_uniform', min=10, max=10000, q=10),
    weight_decay=dict(distribution='log_uniform', min=math.log(0.00001), max=math.log(10)),
    learning_rate=dict(distribution='log_uniform', min=math.log(0.00001), max=math.log(10)),
)

sweep_config = dict(
    method='bayes',
    # NOTE(review): `train` calls wandb.log only once per run, so hyperband (which
    # acts on intermediate logged values) probably never stops a run early.
    early_terminate=dict(type='hyperband', min_iter=8),
    metric=metric,
    parameters=params,
)

In [None]:
# Select the default CUDA device: make GPU index `cuda` current so later bare
# `.cuda()` calls place tensors on it. Skipped entirely when `cuda` is None.
if cuda is not None:
    print("using cuda", cuda)
    torch.cuda.set_device(cuda)

In [None]:
# Get experiment parameters: look up trial `trial_hash` in the experiments index CSV
# and turn its row into the parameter dict used to rebuild the dataset and model.
config_parser = ConfigParser(experimentsPath, dataPath, experimentName)
experimentsFileNameAndPath = os.path.join(experimentsPath, experimetnsFileName)
if not os.path.exists(experimentsFileNameAndPath):
    raise FileNotFoundError("Experiments file " + experimentsFileNameAndPath + " not found!")
experiments_df = pd.read_csv(experimentsFileNameAndPath)

experimentRecord = experiments_df[experiments_df["trialHash"] == trial_hash]
# Fail with a clear message instead of an opaque IndexError from .iloc[0] below.
if experimentRecord.empty:
    raise LookupError("Experiment " + trial_hash + " not found in " + experimentsFileNameAndPath + "!")
modelName = experimentRecord.iloc[0]["modelName"]
experimentPathAndName = os.path.join(experimentsPath, experimentName)
trialName = os.path.join(experimentPathAndName, modelName)
experiment_params = experimentRecord.to_dict('records')[0]
experiment_params = config_parser.fixExperimentParams(experiment_params)

# A blank `suffix` cell presumably comes through from the CSV as NaN; normalize it
# to None. pd.isna (unlike math.isnan, which raises TypeError on strings) also
# copes with a real string suffix.
if pd.isna(experiment_params['suffix']):
    experiment_params['suffix'] = None
print(experiment_params)

In [None]:
# Dataset manager built from this experiment's directory and the data root.
experimentPathAndName = os.path.join(experimentsPath,
                                     experimentName)
datasetManager = dataLoader.datasetManager(experimentPathAndName,
                                           dataPath)

In [None]:
# Configure the dataset manager with the trial's parameters (paths fixed up by the
# config parser) before any loaders are requested.
fixedExperimentParams = config_parser.fixPaths(experiment_params)
datasetManager.updateParams(fixedExperimentParams)

In [None]:
%%capture
train_loader, validation_loader, test_loader = datasetManager.getLoaders(SEED_INT)

# Rebuild the trial's network. The architecture is derived from the experiment
# parameters plus the training split's csv_processor (presumably supplying the
# fine/coarse class lists); loadModel then restores the model saved under trialName.
csv_processor = train_loader.dataset.csv_processor
architecture = CNN.get_architecture(experiment_params, csv_processor)
model = CNN.create_model(architecture, experiment_params, device=cuda)
CNN.loadModel(model, trialName, device=cuda)

In [None]:
def getTransformedImage(dataset, img, augmentation, normalization):
    """Apply ``dataset``'s transform pipeline to ``img`` under temporary loading flags.

    The dataset's loading flags are switched via ``toggle_image_loading``, the
    transforms it then reports are composed and applied to ``img``, and the
    previous flags are restored afterwards -- even if a transform raises, so the
    dataset is never left in the temporary configuration.

    Args:
        dataset: Project dataset exposing ``toggle_image_loading`` and ``getTransforms``.
        img: Input image (presumably a PIL image -- confirm against the dataset's transforms).
        augmentation: Forwarded as ``toggle_image_loading(augmentation=...)``.
        normalization: Forwarded as ``toggle_image_loading(normalization=...)``.

    Returns:
        Whatever the composed transforms produce for ``img``.
    """
    # toggle_image_loading hands back the previous (augmentation, normalization, pad)
    # settings, which are passed back to it below to restore the dataset's state.
    prev_augmentation, prev_normalization, prev_pad = dataset.toggle_image_loading(
        augmentation=augmentation, normalization=normalization)
    try:
        # Named `pipeline` so it does not shadow the notebook-level `transforms` module.
        pipeline = torchvision_transforms.Compose(dataset.getTransforms())
        return pipeline(img)
    finally:
        dataset.toggle_image_loading(prev_augmentation, prev_normalization, prev_pad)

In [None]:
# Show the reference image after the dataset's (non-augmenting) transforms,
# without and with normalization.
fig, axarr = plt.subplots(1, 2)

title = ntpath.basename(fileName)
original = Image.open(fileName)
image_non_normalized = getTransformedImage(test_loader.dataset, original, False, False)
image_normalized = getTransformedImage(test_loader.dataset, original, False, True)

normalized_hwc = np.transpose(image_normalized.detach().numpy(), (1, 2, 0))
# Normalized values typically fall outside [0, 1], which imshow clips for float RGB
# data; min-max rescale for display only (`image_normalized` itself is untouched).
value_span = normalized_hwc.max() - normalized_hwc.min()
if value_span > 0:
    normalized_hwc = (normalized_hwc - normalized_hwc.min()) / value_span

fig.suptitle(title)
axarr[0].imshow(np.transpose(image_non_normalized.detach().numpy(), (1, 2, 0)))
axarr[0].set_title("non-normalized")
axarr[1].imshow(normalized_hwc)
axarr[1].set_title("normalized (rescaled for display)")
for ax in axarr:
    ax.axis('off')
plt.show()

In [None]:
# Classify the normalized reference image; its predicted fine class (`output_class`)
# is the target class used by the activation-maximization trials.
model.eval()
img = image_normalized.unsqueeze(0)  # add a batch dimension
if cuda is not None:
    img = img.cuda()
# Pure inference: no autograd graph is needed for this forward pass.
with torch.no_grad():
    output = model(img)
output_class = torch.argmax(output['fine'].squeeze())
# NOTE(review): `im` is not used elsewhere in this notebook.
im = transforms.ToPILImage()(image_non_normalized).convert("RGB")

In [None]:
import wandb

# Authenticate with Weights & Biases before creating the sweep. How credentials are
# obtained (cached login, WANDB_API_KEY, or an interactive prompt) is up to wandb --
# see its docs; never hardcode an API key in this notebook.
wandb.login()

In [None]:
from pytorchVisualizations.generate_class_specific_samples import ClassSpecificImageGeneration
from pytorchVisualizations.generate_regularized_class_specific_samples import RegularizedClassSpecificImageGeneration
from PIL import Image
import matplotlib.pyplot as plt

def _tensor_to_pil(image_tensor):
    """Convert a float image tensor in [0, 1] to a PIL image.

    Accepts a 3-D (C, H, W) tensor, optionally with extra singleton dims (they are
    squeezed away). Values are clipped to [0, 1] before scaling: optimized images
    may drift out of range, and casting out-of-range floats to uint8 is undefined
    (it commonly wraps around) rather than saturating.
    """
    pixels = image_tensor.detach().cpu().squeeze().permute(1, 2, 0).numpy()
    pixels = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(pixels)


def train(config=None):
    """Run one activation-maximization trial and log the result to Weights & Biases.

    Starting from the reference image (or a random image), a generator optimizes an
    input for ``output_class`` -- the class the model predicted for the reference
    image. The function then logs:
      * ``loss``: the loss returned by the generator,
      * ``distribution``: bar chart of the model's softmax over fine classes for
        the generated image,
      * ``output``: the generated image.

    Args:
        config: Optional mapping with ``useRegularization``, ``useRandomInitImage``,
            ``iterations``, ``learning_rate`` and, when regularization is enabled,
            ``weight_decay`` and ``L1_reg``. When None (the sweep-agent case),
            ``wandb.config`` is used.

    Uses notebook globals: ``model``, ``output_class``, ``image_non_normalized``,
    ``test_loader``, ``cuda`` and the experiment identifiers.
    """
    # Context manager: the run is finished (and marked failed) even if generation
    # raises, instead of being left open for the agent's next trial.
    with wandb.init(config=config):
        sweep_params = dict(wandb.config if config is None else config)
        row_information = {
            "experimentsPath": experimentsPath,
            "dataPath": dataPath,
            "experimentName": experimentName,
            "trial_hash": trial_hash,
            "fileName": fileName,
            **sweep_params,
        }
        print(row_information)

        target_class = output_class.item()
        normalizer = test_loader.dataset.normalizer
        init_img = None if sweep_params["useRandomInitImage"] else image_non_normalized.unsqueeze(0)
        # q_uniform sweep values may arrive as floats (e.g. 1230.0); iteration counts must be ints.
        iterations = int(sweep_params["iterations"])
        learning_rate = sweep_params["learning_rate"]

        if sweep_params["useRegularization"]:
            generator = RegularizedClassSpecificImageGeneration(
                model, target_class, init_img, cuda=cuda, normalizer=normalizer)
            im_generated, loss = generator.generate(
                iterations=iterations, initial_learning_rate=learning_rate,
                wd=sweep_params["weight_decay"], L1_reg=sweep_params["L1_reg"])
        else:
            generator = ClassSpecificImageGeneration(
                model, target_class, init_img, cuda=cuda, normalizer=normalizer)
            im_generated, loss = generator.generate(
                iterations=iterations, initial_learning_rate=learning_rate)

        # Score the generated image: inference only, on the model's own device.
        model_device = next(model.parameters()).device
        with torch.no_grad():
            batch = normalizer(im_generated).unsqueeze(0).to(model_device)
            # Explicit dim=1 (class axis of the (batch, classes) logits) instead of the
            # deprecated implicit-dim nn.Softmax().
            probabilities = F.softmax(model(batch)['fine'], dim=1)[0].tolist()

        dist_table = wandb.Table(
            data=[[class_idx, prob] for class_idx, prob in enumerate(probabilities)],
            columns=["class", "probability"])
        wandb.log({
            "loss": loss,
            "distribution": wandb.plot.bar(dist_table, 'class', 'probability'),
            "output": wandb.Image(_tensor_to_pil(im_generated)),
        })

In [None]:
# Register the sweep with W&B, then let an agent run up to COUNT trials of `train`.
agent_settings = dict(function=train, count=COUNT)
sweep_id = wandb.sweep(sweep_config, project=project)
wandb.agent(sweep_id, **agent_settings)