# Imports

In [1]:
import os, sys
sys.path.append("..")
sys.path.append("../ALAE")
import random
import string

import torch
import numpy as np

# from src.light_sb import LightSB
from ofmsrc.alae_distributions import LoaderSampler, TensorSampler
# import deeplake
from tqdm import tqdm

import wandb
from matplotlib import pyplot as plt

from alae_ffhq_inference import load_model, encode, decode

In [2]:
def seed_all(seed=123):
    """Seed every RNG the notebook draws from: torch, NumPy and Python ``random``.

    Args:
        seed: integer seed applied identically to all three generators
            (in the order torch, NumPy, ``random``).
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

## Config

In [3]:
# Dimensionality of the latent vectors the transport map operates on.
DIM = 512
assert DIM > 1

# Source / target sub-populations of the latents. The data-loading cell selects
# rows by gender label ("MAN"/"WOMAN") or by age ("ADULT"/"CHILDREN").
DATA_CHOICES = ("MAN", "WOMAN", "ADULT", "CHILDREN")
INPUT_DATA = "ADULT"
TARGET_DATA = "CHILDREN"
# Fail fast: an unknown label would otherwise leave the index arrays of the
# data-loading cell undefined (or silently stale from a previous run).
assert INPUT_DATA in DATA_CHOICES, f"unknown INPUT_DATA: {INPUT_DATA!r}"
assert TARGET_DATA in DATA_CHOICES, f"unknown TARGET_DATA: {TARGET_DATA!r}"

OUTPUT_SEED = 0xBADBEEF
BATCH_SIZE = 128
D_LR = 1e-3  # 1e-3 for eps 0.1
INV_TOLERANCE = 1e-2  # passed as `tolerance_grad` to the L-BFGS inverse solver

MAX_STEPS = 10002  # exclusive upper bound of the training-step range
CONTINUE = -1  # training starts at step CONTINUE + 1

In [4]:
seed_all(OUTPUT_SEED)  # re-seed all RNGs with the experiment seed before loading data

In [5]:
# Experiment identifier: used for the wandb run name and the checkpoint folder.
CODE = 'MB128'
EXP_NAME = f'ofm_ALAE_{INPUT_DATA}_TO_{TARGET_DATA}_{CODE}'
OUTPUT_PATH = '../checkpoints/{}'.format(EXP_NAME)

# Hyper-parameters logged to wandb. The transfer direction, seed and step budget
# are included so a logged run can be reproduced from its config alone.
config = dict(
    DIM=DIM,
    D_LR=D_LR,
    BATCH_SIZE=BATCH_SIZE,
    INV_TOLERANCE=INV_TOLERANCE,
    CODE=CODE,
    INPUT_DATA=INPUT_DATA,
    TARGET_DATA=TARGET_DATA,
    OUTPUT_SEED=OUTPUT_SEED,
    MAX_STEPS=MAX_STEPS,
)

# exist_ok avoids the exists()/makedirs() race and makes the cell idempotent.
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Data loading

In [None]:
import gdown
import os

os.makedirs('../data', exist_ok=True)

# Local destination path -> Google Drive URL of the data used by this notebook.
urls = {
    "../data/age.npy": "https://drive.google.com/uc?id=1Vi6NzxCsS23GBNq48E-97Z9UuIuNaxPJ",
    "../data/gender.npy": "https://drive.google.com/uc?id=1SEdsmQGL3mOok1CPTBEfc_O1750fGRtf",
    "../data/latents.npy": "https://drive.google.com/uc?id=1ENhiTRsHtSjIjoRu1xYprcpNd8M9aVu8",
    "../data/test_images.npy": "https://drive.google.com/uc?id=1SjBWWlPjq-dxX4kxzW-Zn3iUR3po8Z0i",
}

for name, url in urls.items():
    # Skip files that are already present so re-running the notebook does not
    # re-fetch the whole dataset every time.
    if os.path.exists(name):
        print(f"{name} already exists, skipping download")
        continue
    gdown.download(url, name, quiet=False)

In [6]:
# The .npy files below are fetched by the download cell above.

train_size = 60000
test_size = 10000

latents = np.load("../data/latents.npy")
gender = np.load("../data/gender.npy")
age = np.load("../data/age.npy")
test_inp_images = np.load("../data/test_images.npy")

train_latents, test_latents = latents[:train_size], latents[train_size:]
train_gender, test_gender = gender[:train_size], gender[train_size:]
train_age, test_age = age[:train_size], age[train_size:]


def _select_indices(domain, gender_labels, age_labels):
    """Return the ascending row positions of one split that belong to ``domain``.

    Args:
        domain: one of "MAN", "WOMAN", "ADULT", "CHILDREN".
        gender_labels: per-row gender labels, compared against "male"/"female".
        age_labels: per-row ages; rows whose age is -1 are excluded from the
            age-based domains.

    Returns:
        1-D integer array of positions into the split the labels came from.

    Raises:
        ValueError: if ``domain`` is not one of the labels above.
    """
    gender_flat = np.asarray(gender_labels).reshape(-1)
    age_flat = np.asarray(age_labels).reshape(-1)
    known_age = age_flat != -1
    if domain == "MAN":
        mask = gender_flat == "male"
    elif domain == "WOMAN":
        mask = gender_flat == "female"
    elif domain == "ADULT":
        mask = (age_flat >= 18) & known_age
    elif domain == "CHILDREN":
        mask = (age_flat < 18) & known_age
    else:
        raise ValueError(
            f"Unknown domain {domain!r}; expected MAN, WOMAN, ADULT or CHILDREN"
        )
    # flatnonzero derives positions from the mask itself, so the result does not
    # depend on train_size/test_size matching the actual split length.
    return np.flatnonzero(mask)


x_inds_train = _select_indices(INPUT_DATA, train_gender, train_age)
x_inds_test = _select_indices(INPUT_DATA, test_gender, test_age)
x_data_train = train_latents[x_inds_train]
x_data_test = test_latents[x_inds_test]

y_inds_train = _select_indices(TARGET_DATA, train_gender, train_age)
y_inds_test = _select_indices(TARGET_DATA, test_gender, test_age)
y_data_train = train_latents[y_inds_train]
y_data_test = test_latents[y_inds_test]

X_train = torch.tensor(x_data_train)
Y_train = torch.tensor(y_data_train)

X_test = torch.tensor(x_data_test)
Y_test = torch.tensor(y_data_test)

X_sampler = TensorSampler(X_train, device="cpu")
Y_sampler = TensorSampler(Y_train, device="cpu")

# Model initialisation

## OFM

In [7]:
from ofmsrc.icnn import (
    ICNNCPF,
    ICNN2CPF,
    ICNN3CPF,
    LseICNNCPF,
    ResICNN2CPF,
    DenseICNN2CPF,
    ICNN2CPFnoSPact
)
from ofmsrc.icnn import (
    LinActnormICNN,
    DenseICNN
)

from ofmsrc.model_tools import (
    id_pretrain_model,
    ofm_forward,
    ofm_inverse,
    ofm_loss
)
from ofmsrc.model_tools import (
    TorchlbfgsInvParams,
    TorchoptimInvParams,
    BruteforceIHVParams
)

from ofmsrc.alae_distributions import NormalSampler

from ofmsrc.tools import EMA

In [8]:
class SamplerProxy:
    """Adapter that forwards sampling to a wrapped sampler and moves the result to a device.

    ``sample`` receives an indexable argument (presumably a shape tuple such as
    ``(batch_size,)`` — confirm with callers) and forwards only its first
    element to the wrapped sampler.
    """

    def __init__(self, sampler, device='cuda'):
        self.sampler, self.device = sampler, device

    def sample(self, tpl):
        batch_size = tpl[0]
        drawn = self.sampler.sample(batch_size)
        return drawn.to(self.device)

In [9]:
seed_all(OUTPUT_SEED)  # re-seed so the network initialisation below is reproducible

In [10]:
# NOTE(review): dimh, num_hidl and pretrain_sampler are not used by any visible
# cell (D below is built with explicit layer sizes and id_pretrain_model is
# never called); kept so the notebook namespace is unchanged.
dimh = 256
num_hidl = 3
pretrain_sampler = NormalSampler(
    np.zeros(DIM),
    cov=4. * np.eye(DIM),
)

# ICNN potential D of dimension DIM with two hidden layers of width 1024.
hidden_layer_sizes = [1024, 1024]
D = DenseICNN(DIM, hidden_layer_sizes).cuda()

In [11]:
D_opt = torch.optim.Adam(params=D.parameters(), lr=D_LR)  # Adam over all parameters of D

In [12]:
print(f"{D_LR}")  # echo the learning rate used for D

0.001


## Training in the ALAE latent space

In [13]:
from ofmsrc.discrete_ot import OTPlanSampler

def get_discrete_ot_plan_sample_fn(sampler_x, sampler_y, device='cuda'):
    """Build a batch sampler whose pairs are coupled by an exact discrete OT plan.

    Every call draws independent mini-batches from ``sampler_x`` and
    ``sampler_y``, moves both to ``device`` and passes them to
    ``OTPlanSampler('exact').sample_plan``.

    Args:
        sampler_x: source sampler exposing ``sample(batch_size)``.
        sampler_y: target sampler exposing ``sample(batch_size)``.
        device: device the mini-batches are moved to before the OT solve.

    Returns:
        A function ``batch_size -> sample_plan(x, y)``. The training loop
        unpacks its result as an ``(X, Y)`` pair.
    """
    plan = OTPlanSampler('exact')

    def draw_coupled_batch(batch_size):
        source = sampler_x.sample(batch_size).to(device)
        target = sampler_y.sample(batch_size).to(device)
        return plan.sample_plan(source, target)

    return draw_coupled_batch


sampling_fn = get_discrete_ot_plan_sample_fn(X_sampler, Y_sampler)

In [14]:
# To download the required model, run training_artifacts/download_all.py in the ALAE folder.

# model = load_model("../ALAE/configs/ffhq.yaml", training_artifacts_dir="../ALAE/training_artifacts/ffhq/")

In [15]:
# Decay factors of the averaged copies of D maintained during training; the
# training loop updates them via ema(D) and checkpoints each as D_{beta}.pth.
EMA_BETAS = [0.999, 0.99]
ema = EMA(0, betas=EMA_BETAS)  # NOTE(review): meaning of the first argument (0) not visible here — confirm in ofmsrc.tools

In [17]:
# Log metrics to Weights & Biases (default: on). Set the environment variable
# USE_WANDB=0 to train without a wandb login, e.g. for a quick Restart & Run All.
USE_WANDB = os.environ.get("USE_WANDB", "1") != "0"

In [1]:
def _save_checkpoint():
    """Write D, its optimizer state and every EMA copy of D into OUTPUT_PATH."""
    torch.save(D.state_dict(), os.path.join(OUTPUT_PATH, 'D.pth'))
    torch.save(D_opt.state_dict(), os.path.join(OUTPUT_PATH, 'D_opt.pt'))
    for beta in EMA_BETAS:
        beta_model_path = os.path.join(OUTPUT_PATH, 'D_{}.pth'.format(beta))
        torch.save(ema.get_model(beta).state_dict(), beta_model_path)


if USE_WANDB:
    wandb.init(name=EXP_NAME, project='ofm', config=config)

try:
    with tqdm(range(CONTINUE + 1, MAX_STEPS)) as tbar:
        for step in tbar:
            D_opt.zero_grad()

            current_metrics = dict()
            # Mini-batch pairs coupled by the discrete OT plan. sampling_fn already
            # moves them to CUDA; the .cuda() calls are a cheap safeguard.
            X, Y = sampling_fn(BATCH_SIZE)
            X = X.cuda()
            Y = Y.cuda()
            # Uniform times shifted by 1e-8 so t is never exactly 0.
            t = (torch.rand(BATCH_SIZE) + 1e-8).cuda()

            loss, true_loss = ofm_loss(D, X, Y, t,
                    TorchlbfgsInvParams(lbfgs_params=dict(tolerance_grad = INV_TOLERANCE), max_iter=10),
                    BruteforceIHVParams(),
                    tol_inverse_border = 0.5,
                    stats=current_metrics)

            # NOTE(review): gradients are cleared again after ofm_loss, presumably
            # because its inverse solve can leave gradients on D's parameters.
            D_opt.zero_grad()
            loss.backward()
            D_opt.step()
            D.convexify()  # presumably re-projects D onto its convexity constraints — confirm in ofmsrc.icnn
            ema(D)
            current_metrics['loss'] = loss.item()
            current_metrics['true_loss'] = true_loss.item()

            if USE_WANDB:
                wandb.log(current_metrics, step=step)

            # Checkpoint at steps 1, 101, 201, ...; with MAX_STEPS = 10002 the
            # final step (10001) is also saved.
            if step % 100 == 1:
                _save_checkpoint()
            tbar.set_postfix(loss=current_metrics['loss'], tloss=current_metrics['true_loss'])
finally:
    # Close the wandb run even if training raises or is interrupted, so a re-run
    # of this cell does not leave a dangling run behind.
    if USE_WANDB:
        wandb.finish()

# Results plotting

In [None]:
# To download the required model, run training_artifacts/download_all.py in the ALAE folder.

alae_model = load_model("../ALAE/configs/ffhq.yaml", training_artifacts_dir="../ALAE/training_artifacts/ffhq/")
torch.manual_seed(OUTPUT_SEED); np.random.seed(OUTPUT_SEED)

# Choose 10 source test examples among those whose test index is < 300.
# x_inds_test is ascending, so its first (x_inds_test < 300).sum() positions are
# exactly those examples. NOTE(review): 300 presumably equals
# len(test_inp_images) — confirm.
inds_to_map = np.random.choice(np.arange((x_inds_test < 300).sum()), size=10, replace=False)
number_of_samples = 3

latent_to_map = torch.tensor(test_latents[x_inds_test[inds_to_map]])
inp_images = test_inp_images[x_inds_test[inds_to_map]]


def _to_uint8_hwc(images):
    """Convert decoder output to a uint8 numpy array with axes reordered (0, 2, 3, 1).

    Values are mapped by ``(x * 0.5 + 0.5) * 255``, which presumes the decoder
    output lies in [-1, 1]. Values are truncated to integers and clamped to [0, 255].
    """
    scaled = (images * 0.5 + 0.5) * 255
    return scaled.type(torch.long).clamp(0, 255).cpu().type(torch.uint8).permute(0, 2, 3, 1).numpy()


mapped_all = []
with torch.no_grad():
    # The input is identical for every sample, so move it to the GPU once.
    latent_on_gpu = latent_to_map.cuda()
    # NOTE(review): if push_nograd is deterministic, all samples are identical.
    for k in range(number_of_samples):
        mapped_all.append(D.push_nograd(latent_on_gpu).cpu())

# Shape: (n_examples, number_of_samples, ...), sample index on axis 1.
mapped = torch.stack(mapped_all, dim=1)

decoded_all = []
with torch.no_grad():
    for k in range(number_of_samples):
        # `mapped` is already on the CPU (stacked from .cpu() tensors).
        decoded_img = decode(alae_model, mapped[:, k])
        decoded_all.append(_to_uint8_hwc(decoded_img))

decoded_all = np.stack(decoded_all, axis=1)

In [None]:

n_pictures = 2
assert n_pictures <= len(inp_images), "n_pictures exceeds the number of mapped examples"

# Grid: one row per example; column 0 is the input image, columns
# 1..number_of_samples are the decoded mapped samples.
# squeeze=False keeps `axes` 2-D even when n_pictures == 1, so row indexing
# always yields a sequence of Axes.
fig, axes = plt.subplots(
    n_pictures, number_of_samples + 1,
    figsize=(number_of_samples + 1, n_pictures), dpi=200, squeeze=False,
)

for ind in range(n_pictures):
    row_axes = axes[ind]
    row_axes[0].imshow(inp_images[ind])
    for k in range(number_of_samples):
        row_axes[k + 1].imshow(decoded_all[ind, k])
    # Hide the x axis and y ticks on every panel of the row.
    for ax in row_axes:
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])

fig.tight_layout(pad=0.05)
# Filename derived from the experiment CODE ('MB128' -> 'ofm_transfer_mb128.png').
fig.savefig(f'ofm_transfer_{CODE.lower()}.png')
