# SHAPE COLOR EXPERIMENT ENTROPIC

In [None]:
import torch
import torchvision

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm import tqdm
import itertools
import wandb
import os

import sys
sys.path.append("../..")
from src.utils import Config, weights_init_D, freeze, unfreeze, normalize_out_to_0_1, middle_rgb
from src.data import DatasetSampler
from src.models import ResNet_D, UNet, linear_model
from src.train_shape_color_entropic import train_shape_color_entropic
from src.cost import strong_cost, cost_image_color_latent, cost_image_shape_latent

# for generator
import sys 
sys.path.append("..")
import dnnlib
import legacy

import warnings
warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2

In [None]:
# Training schedule.
NUM_EPOCHS, INNER_ITERATIONS, BATCH_SIZE = 1, 10, 64

# Entropic regularization strength (copied into CONFIG.EPSILON below).
EPSILON = 0.1

# Learning rates for the latent MLP, the encoders and the potential network.
LR_LATENT_MLP, LR_ENCODER, LR_POTENTIAL = 1e-4, 1e-4, 1e-4

# Index of the CUDA device to train on.
GPU_DEVICE = 0

# Variant switches.
# NOTE(review): these are not referenced by the cells shown here — confirm
# they are still consumed somewhere before relying on them.
ALAE, RESNET_D_ENCODER, CHI, SPHERE = False, True, False, False

## 1. Config

In [None]:
CONFIG = Config()

# Marginals: MNIST digit(s) from CLASSES plus the color palette.
CONFIG.K = 2  # NOTE(review): presumably the number of marginals (len(CLASSES) + 1 color sampler) — confirm
CONFIG.CLASSES = [2]  # MNIST digit label(s) kept by the data-sampler cell
CONFIG.LAMBDAS = [.5, .5]  # NOTE(review): presumably one weight per marginal (length K) — confirm

CONFIG.EPSILON = EPSILON
CONFIG.NUM_EPOCHS = NUM_EPOCHS
CONFIG.INNER_ITERATIONS = INNER_ITERATIONS
CONFIG.BATCH_SIZE = BATCH_SIZE

CONFIG.DATASET = 'colored_mnist'
CONFIG.IMG_SIZE = 32
CONFIG.NC = 1  # channels of the images fed to the image encoder (grayscale MNIST)
CONFIG.DATASET_PATH = "../../../data/MNIST"

CONFIG.FLAG_LATENT = True  # when True, the StyleGAN checkpoint below is loaded
CONFIG.GENERATOR_PATH = "../../../stylegan2_ada_pytorch_before/ckpts/MNIST-colored_2_3/network-snapshot-002000.pkl"
CONFIG.LATENT_SIZE = 512

CONFIG.LR_LATENT_MLP = LR_LATENT_MLP
CONFIG.LR_ENCODER = LR_ENCODER
CONFIG.LR_POTENTIAL = LR_POTENTIAL
CONFIG.BETAS = (0.2, 0.99)  # Adam betas used by the optimizers in this notebook

CONFIG.NUMBER_PALETTES = 50_000
CONFIG.HUE_MEAN = 120  # in degrees, range 0..360; 120 is green
CONFIG.HUE_MEANS = [0, 60, 120]
CONFIG.HUE_STD = 0.
CONFIG.SATURATION = 1  # from 0 to 1
CONFIG.BRIGHTNESS = 1  # from 0 to 1
CONFIG.SATURATION_THRESHOLD = 0.8  # from 0 to 1; passed to middle_rgb

CONFIG.GPU_DEVICE = GPU_DEVICE
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("This experiment requires CUDA, but torch.cuda.is_available() is False.")
CONFIG.DEVICE = f'cuda:{CONFIG.GPU_DEVICE}'

## 2. StyleGAN generator

In [None]:
# Load the pretrained generator from the 'G_ema' entry of the checkpoint.
# G is always bound: when FLAG_LATENT is False it stays None, so later cells
# that pass `generator=G` receive an explicit None instead of hitting a NameError.
# NOTE(review): legacy.load_network_pkl appears to unpickle the checkpoint —
# only point GENERATOR_PATH at trusted files.
G = None
if CONFIG.FLAG_LATENT:
    with dnnlib.util.open_url(CONFIG.GENERATOR_PATH) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(CONFIG.DEVICE)

## 3. Make dataset of palettes

Here we build a dataset of colors $\mathbb{D}$, sampled in HSV space and then converted to RGB:

- $\forall x \in \mathbb{D}: x \in \mathbb{R}^{3}$
- $\forall x \in \mathbb{D}$: $x[0] \in [0,1]$ is the **hue**, and $x[1], x[2] \in [0,1]$ are the **saturation** and **brightness**, respectively.

Steps:
- Set the mean **hue** to 120°, the middle of the green range.
- Define the hue standard deviation and the number of samples in the dataset.
- Set constant **saturation** and **brightness**.
- Stack the three channels into a single ndarray.
- Rescale **hue** from degrees (0 to 360) to [0, 1] by dividing by 360.
- Convert the HSV vectors to RGB with `matplotlib.colors.hsv_to_rgb`.
- Plot the resulting color palette.

In [None]:
# Sample NUMBER_PALETTES colors in HSV space: Gaussian hue around HUE_MEAN
# (in degrees) with constant saturation and brightness, then convert to RGB.
n_palettes = CONFIG.NUMBER_PALETTES

# Hue is an angle, so wrap samples into [0, 360). Without the wrap, e.g.
# HUE_MEAN=0 with HUE_STD>0 yields negative hues, i.e. values outside the
# [0, 1] range that hsv_to_rgb assumes after the /360 rescaling below.
hue_vectors = np.mod(CONFIG.HUE_MEAN + np.random.randn(n_palettes) * CONFIG.HUE_STD, 360)  # shape: (NUMBER_PALETTES,)
# Alternative hue distribution (uniform integers in [0, 130]), kept disabled:
#hue_vectors = np.random.randint(low=0, high=131, size=CONFIG.NUMBER_PALETTES)
saturation_vectors = np.full(n_palettes, CONFIG.SATURATION, dtype=float)  # shape: (NUMBER_PALETTES,)
brightness_vectors = np.full(n_palettes, CONFIG.BRIGHTNESS, dtype=float)  # shape: (NUMBER_PALETTES,)

# Stack the channels column-wise -> shape: (NUMBER_PALETTES, 3).
hsv_vectors = np.column_stack([hue_vectors, saturation_vectors, brightness_vectors])

# matplotlib expects every HSV component in [0, 1]: rescale hue from degrees.
hsv_vectors[:, 0] = hsv_vectors[:, 0] / 360

# HSV -> RGB: https://matplotlib.org/stable/api/_as_gen/matplotlib.colors.hsv_to_rgb.html
rgb_dataset = matplotlib.colors.hsv_to_rgb(hsv_vectors)
assert rgb_dataset.shape == (CONFIG.NUMBER_PALETTES, 3)

In [None]:
# plot palette of RGB dataset
# Scatter the first N palette colors at random positions to eyeball the RGB dataset.
# Cap N at the dataset size so the colors always match the number of points.
N = min(1000, len(rgb_dataset))
x = np.random.rand(N)
y = np.random.rand(N)
c = rgb_dataset[:N]  # slice with N rather than a second hard-coded 1000

plt.scatter(x, y, c=c, label="RGB dataset")
plt.legend()
plt.show()

In [None]:
# Unlabeled sampler over the RGB palette colors built above.
# NOTE(review): batch_size=256 differs from CONFIG.BATCH_SIZE — confirm how DatasetSampler uses it.
color_sampler = DatasetSampler(rgb_dataset, batch_size=256, flag_label=False)

In [None]:
NUM_PLOT_GENERATED = 20

# Sample images from the generator and extract one color per image with
# middle_rgb (NOTE(review): presumably the mean color of pixels passing
# SATURATION_THRESHOLD — confirm). Visualization only, so no autograd graph.
with torch.no_grad():
    fake = G(torch.randn(NUM_PLOT_GENERATED, CONFIG.LATENT_SIZE).to(CONFIG.DEVICE), c=None)
    clr = middle_rgb(normalize_out_to_0_1(fake, CONFIG), CONFIG.SATURATION_THRESHOLD)

# Row of generated images.
fig, ax = plt.subplots(1, NUM_PLOT_GENERATED, figsize=(10, 20), dpi=150)
for i in range(NUM_PLOT_GENERATED):
    ax[i].imshow(normalize_out_to_0_1(fake[i], CONFIG).permute(1, 2, 0).detach().cpu())
    ax[i].set_xticks([])
    ax[i].set_yticks([])
fig.tight_layout(pad=0.01)

# Row of extracted colors, one disc per generated image.
figure, axes = plt.subplots(1, NUM_PLOT_GENERATED, figsize=(10, 20), dpi=150)
for i in range(NUM_PLOT_GENERATED):
    axes[i].set_aspect(1)
    axes[i].add_artist(plt.Circle((0.5, 0.5), 0.4, color=clr[i].cpu().numpy()))
    axes[i].set_xticks([])
    axes[i].set_yticks([])
# Lay out the palette figure itself (previously `fig` was laid out twice instead).
figure.tight_layout(pad=0.01)

## 4. Data Samplers

In [None]:
# Resize MNIST digits to IMG_SIZE x IMG_SIZE and convert them to tensors.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((CONFIG.IMG_SIZE, CONFIG.IMG_SIZE)),
    torchvision.transforms.ToTensor(),
])

# One labeled sampler per MNIST digit listed in CONFIG.CLASSES.
data_samplers = []
for digit in CONFIG.CLASSES:
    # A fresh dataset object per digit, because its data/targets are filtered in place.
    dataset = torchvision.datasets.MNIST(root=CONFIG.DATASET_PATH, download=True,
                                         transform=transform)
    # Vectorized boolean mask instead of a Python loop over every 0-d target tensor.
    targets = torch.as_tensor(dataset.targets)
    mask = targets == digit
    dataset.targets = targets.numpy()[mask.numpy()]  # kept as an ndarray, as before
    dataset.data = torch.as_tensor(dataset.data)[mask]
    data_samplers.append(DatasetSampler(dataset, flag_label=True, batch_size=256))

In [None]:
# The color sampler goes last, after the per-digit MNIST samplers.
data_samplers += [color_sampler]

## 5. Networks

### 5.1 Encoder

In [None]:
# Encoder 0: ResNet_D over CONFIG.NC-channel images, 2 * LATENT_SIZE outputs.
image_encoder = ResNet_D(
    size=CONFIG.IMG_SIZE,
    nc=CONFIG.NC,
    nfilter=64,
    nfilter_max=512,
    res_ratio=0.1,
    n_output=2 * CONFIG.LATENT_SIZE,
    bn_flag=True,
    pn_flag=True,
).to(CONFIG.DEVICE)
weights_init_D(image_encoder)

# Encoder 1: MLP over 3-dim color vectors, same output width as encoder 0.
color_encoder = linear_model(3, [64, 128, 256], 2 * CONFIG.LATENT_SIZE).to(CONFIG.DEVICE)

encoder = [image_encoder, color_encoder]

# A single Adam optimizer over the parameters of both encoders.
param_enc = [net.parameters() for net in encoder]
encoder_opt = torch.optim.Adam(itertools.chain.from_iterable(param_enc),
                               CONFIG.LR_ENCODER, betas=CONFIG.BETAS)

# No latent MLP in this experiment.
latent_mlp, latent_mlp_opt = None, None

### 5.2 Potentials

In [None]:
# Potential network(s): a ResNet_D on 3-channel images with a single output.
nets_for_pot = [
    ResNet_D(size=CONFIG.IMG_SIZE,
             nc=3,
             nfilter=64,
             nfilter_max=512,
             res_ratio=0.1,
             n_output=1, bn_flag=False, pn_flag=False).to(CONFIG.DEVICE),
]

for pot_net in nets_for_pot:
    weights_init_D(pot_net)

# Optimize every potential in the list, not only nets_for_pot[0], so adding
# another potential does not leave it silently untrained.
nets_for_pot_opt = torch.optim.Adam(
    itertools.chain.from_iterable(net.parameters() for net in nets_for_pot),
    CONFIG.LR_POTENTIAL, betas=CONFIG.BETAS)
 

## 6. Train

In [None]:
# Run name encodes the key hyperparameters.
# NOTE(review): the "mltp" value is a hard-coded 100 — confirm what it tracks.
mltp_tag = 100
name_exp = (
    f"KL_EPS_{EPSILON}"
    f"_LMBD_{CONFIG.LAMBDAS}"
    f"_mltp_{mltp_tag}"
    f"_THRESHOLD_{CONFIG.SATURATION_THRESHOLD}"
)
CONFIG.NAME_EXP = name_exp

wandb.init(project="BNOT", name=name_exp, config=CONFIG)

In [None]:
# Launch entropic training. Positional order matters: potentials and their
# optimizer, encoders and their optimizer, latent MLP and its optimizer
# (both None here), then the data samplers.
train_shape_color_entropic(
    nets_for_pot, nets_for_pot_opt,
    encoder, encoder_opt,
    latent_mlp, latent_mlp_opt,
    data_samplers,
    generator=G,
    config=CONFIG,
)