## Load libraries

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os,sys
import re
import math
from datetime import datetime
import time
from collections import OrderedDict
from functools import partial
sys.dont_write_bytecode = True
from IPython.core.debugger import set_trace

In [None]:
import json
import pandas as pd
import joblib
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from pathlib import Path
from typing import Any,List, Set, Dict, Tuple, Optional, Iterable, Mapping, Union, Callable, TypeVar

from pprint import pprint
from tqdm import tqdm

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from  torch.linalg import norm as tnorm
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

# import pytorch_lightning as pl
# from pytorch_lightning.core.lightning import LightningModule
# from pytorch_lightning import loggers as pl_loggers
# from pytorch_lightning.tuner.tuning import Tuner
# from pytorch_lightning.callbacks import Callback

# Select Visible GPU
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
import reprlearn as rl

In [None]:
from reprlearn.visualize.utils import get_fig, show_timg, show_timgs, show_npimgs, show_batch, make_grid_from_tensors
from reprlearn.utils.misc import info, now2str, today2str, get_next_version_path, n_iter_per_epoch


## Path to data root dirs
  

In [None]:
exp_data_root = Path("/xxxx")
print("data_root: ", str(exp_data_root))
print("exists?: ", exp_data_root.exists())

## Colorbar: Visualize fft spectra of GMs; abs-diff spectra


In [None]:

fft_dir_root = exp_data_root.parent / 'GM256_fft'

# Load the precomputed dictionary of k:v = model_name: avg_fft.
# NOTE(review): joblib.load unpickles the file, and unpickling can execute
# arbitrary code -- only load pickles produced by a trusted run.
precomputed_dict_ffts_fp = '/data/xxxx/dict-avg-ffts-my-all-pass.pkl'
dict_avg_ffts = joblib.load(precomputed_dict_ffts_fp)


# If no precomputed: 
# compute it and pickle the dict_avg_ffts
# --- Run this --- 
# dict_avg_ffts = load_and_compute_avg_fft_all_subdirs(

#     fft_dir_root=fft_dir_root
# )
# store the dict as pickle:
# joblib.dump(dict_avg_ffts, f"./dict-avg-ffts-my-all-pass-{now2str()}.pkl")


- Show avg-spectrum of each GM

In [None]:
# Show every spectrum in the dictionary (k:v = model_name: avg_fft) in log scale.
# TODO: note that each image is normalized independently
# - we should probably fix a colorbar consistent over all plots..
show_npimgs(
    npimgs=[np.log(avg_fft + 1e-12) for avg_fft in dict_avg_ffts.values()],
    titles=[*dict_avg_ffts],
);

In [None]:
# alternatively:
from reprlearn.visualize.utils import get_fig


# Use the avg-fft dict here: `dict_abs_diff_ffts` and `to_logscale` are only
# defined in later cells, so referencing them at this point raises NameError
# on a fresh kernel (Restart & Run All).
dict_arrs = dict_avg_ffts

n_arrs = len(dict_arrs)
names = np.array(list(dict_arrs.keys()))
# log-magnitude of each array; the small epsilon avoids log(0)
nparrs = np.array([np.log(np.abs(arr) + 1e-12) for arr in dict_arrs.values()])
print(names)
print(nparrs.shape)


In [None]:
fig, axes = get_fig(n_arrs)

- Show abs-diff of spectrum from fft_real

In [None]:
# show abs-diff from fft_real
avg_fft_real = dict_avg_ffts['real-celebahq256']

# helpers
def to_logscale(arr: np.ndarray,
               eps: Optional[float]=1e-12) -> np.ndarray:
    """Return the natural log of the magnitude of `arr`, offset by `eps`.

    Args:
        arr: input array; its absolute value is taken before the log.
        eps: constant added to the magnitude before the log so that zeros do
            not map to -inf. `None` is treated as 0 (no offset).

    Returns:
        `np.log(np.abs(arr) + eps)`, element-wise.
    """
    # Honor the caller's `eps` (it used to be ignored in favor of a literal 1e-12).
    offset = 0.0 if eps is None else eps
    return np.log(np.abs(arr) + offset)

def compute_abs_diff(arr1: np.ndarray, arr2: np.ndarray) -> np.ndarray:
    """Return the element-wise absolute difference |arr1 - arr2|."""
    difference = arr1 - arr2
    # (an l2 distance, np.sum(abs_diff ** 2), was considered as a second return value)
    return np.abs(difference)


    

In [None]:
# Compute abs(fft_gm - fft_real), i.e. the absolute difference of the avg ffts.
# Note: all frequencies are included, i.e. this is not high-pass only.
# precomputed_abs_diff_ffts_fp = '/data/xxx/dict-abs-diff-ffts-my-all-pass.pkl'

dict_abs_diff_ffts = {
    name: compute_abs_diff(spectrum, avg_fft_real)
    for name, spectrum in dict_avg_ffts.items()
}
# store the dict as pickle:
# joblib.dump(dict_abs_diff_ffts, "./dict-abs-diff-ffts-my-all-pass.pkl")



In [None]:
from IPython.core.debugger import set_trace as breakpoint

In [None]:
def show_dict_of_arrs(d_arrays: Dict[str,np.ndarray],
                      transform: Optional[Callable]=None,
                     ) -> Tuple[plt.Figure, plt.Axes]:
    """Show every array of `d_arrays` as an image titled with its key.

    Args:
        d_arrays: mapping from a title to the array to display.
        transform: optional callable applied to each array before display.
            When given, the min/max of each transformed array is printed.

    Returns:
        Whatever `show_npimgs` returns (annotated here as a figure/axes pair).
    """
    # Always hand `show_npimgs` a real list (not a `dict_values` view), so the
    # container type is the same whether or not a transform is applied.
    npimgs = list(d_arrays.values())
    if transform is not None:
        npimgs = [transform(npimg) for npimg in npimgs]

        # debug: report the value range of each transformed image
        for npimg in npimgs:
            print('min, max: ', npimg.min(), npimg.max())

    return show_npimgs(npimgs,
                       titles=list(d_arrays.keys()))
    
    

In [None]:
# show_npimgs(
#     npimgs=list(
#         map(lambda x: np.log(x + 1e-12),
#             list(dict_abs_diff_ffts.values()) )
#     ),
#     titles=list(dict_abs_diff_ffts.keys())
# );

#same: 
show_dict_of_arrs(dict_abs_diff_ffts,
                  transform=to_logscale);

In [None]:
def get_positive_min_max(d: Dict) -> Tuple[float, float]:
    """Return (min, max) over the strictly positive entries of all values in `d`.

    The values are stacked into one array and every entry <= 0 is masked out
    before the extrema are taken.
    """
    stacked = np.array([*d.values()])
    positive_only = np.ma.masked_where(stacked <= 0, stacked, copy=False)
    return positive_only.min(), positive_only.max()
                    

In [None]:
from reprlearn.visualize.utils import show_dict_with_colorbar


def plot_dict_spectra_logscale(dict_to_show, cmap=None, **kwargs):
    """Plot every array of `dict_to_show` with one shared log-scale color mapping.

    Args:
        dict_to_show: mapping whose values are arrays. The (min, max) returned by
            `get_positive_min_max` over all of them become the LogNorm limits.
        cmap: passed through to `show_dict_with_colorbar`.
        **kwargs: forwarded unchanged to `show_dict_with_colorbar`.
    """
    # Imported locally so this function does not rely on a global `colors`
    # name being bound to the matplotlib module at call time.
    from matplotlib import colors

    vmin, vmax = get_positive_min_max(dict_to_show)
    normalizer = colors.LogNorm(
        vmin=vmin,
        vmax=vmax,
        clip=False #shouldn't matter whether set to t/f
    )
    print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
    print('vmin, vmax are set? :', normalizer.scaled())

    show_dict_with_colorbar(dict_to_show, normalizer=normalizer, cmap=cmap,
                            **kwargs);

### Demo:

- Plot the average spectra of each GM (all-pass) in log scale, with the colorbar set consistently over all axes

1.  Plot abs-diff ffts in logscale, with colorbar set consistently over all axes

In [None]:
plot_dict_spectra_logscale(dict_abs_diff_ffts)

In [None]:
# 2. show avg-magnitude-of-spectrum for each gm
# (not abs-diffs of magnitude-spectra)
dict_avg_ffts = dict(sorted(dict_avg_ffts.items()))
plot_dict_spectra_logscale(dict_avg_ffts)

- mask around dc gain


In [None]:
left_stride = 3
right_stride = 2 #in unit of freq_u
center = int(np.ceil(256 / 2))

def set_zero_around_center(fft: np.ndarray, window_size: int=5):
    """Return a copy of a 2D array with a square window around its center set to 0.

    Args:
        fft: 2D array of shape (height, width); it is not modified.
        window_size: side length (in pixels) of the window to zero out.

    Returns:
        A copy of `fft` whose central `window_size` x `window_size` region is 0
        (clipped to the array bounds when the window is larger than the array).
    """
    fft_copy = fft.copy()
    height, width = fft.shape
    center_y, center_x = int(np.ceil(height / 2)), int(np.ceil(width / 2))
    left = int(np.ceil(window_size / 2))
    right = window_size - left

    # Clamp the start indices at 0 so an oversized window does not wrap around
    # through negative indexing.
    row_start, row_stop = max(center_y - left, 0), center_y + right
    # The column window must be built from center_x on both ends (the stop
    # index used to be center_y + right, which is wrong for non-square input).
    col_start, col_stop = max(center_x - left, 0), center_x + right

    fft_copy[row_start:row_stop, col_start:col_stop] = 0.
    return fft_copy


    
    

In [None]:
# frequencies around zero (the DC component) are removed
# plot in logscale

# show_dict_of_arrs(dict_abs_diff_ffts,
#                   transform=lambda fft: to_logscale(
#                                           set_zero_around_center(fft, window_size=3)
#                                         )
#                  );
# joblib.dump(dict_abs_diff_ffts, "./dict-abs-diff-ffts-my-all-pass.pkl")

In [None]:
import toolz

In [None]:
# apply it to a dictionary
dict_avg_ffts_no_dcgain = toolz.valmap(
    partial(set_zero_around_center, window_size=51),
    dict_avg_ffts
)

In [None]:
# show in logscale
# `cm` and `colors` are imported here because the notebook's other
# `from matplotlib import cm, colors` only runs in a later cell, so this cell
# would otherwise raise NameError on a fresh kernel.
from matplotlib import cm, colors

dict_to_show = dict_avg_ffts_no_dcgain

# cmap = 'gray'
cmap=cm.rainbow
# shared log-scale limits taken over the positive entries of all arrays
vmin, vmax = get_positive_min_max(dict_to_show)
normalizer = colors.LogNorm(
    vmin=vmin,
    vmax=vmax,
    clip=False #shouldn't matter whether set to t/f
)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
print('vmin, vmax are set? :', normalizer.scaled())


show_dict_with_colorbar(dict_to_show, normalizer=normalizer, cmap=cmap,
                       show_colorbar_every=1);

## Now compute the abs-diff of these ffts:

- Load precomputed `dict_avg_ffts.pkl` 
- Compute `dict_abs_diff_ffts`; save as pkl file
- Visualize `dict_abs_diff_ffts` in logscale; save as png file

### Helpers

In [None]:
# load helper
def load_ffts(fft_dir_root:Path, common_fn: str) -> Dict[str, Any]:
    """Load each model's stored avg-fft into a dict keyed by sub-directory name.

    Args:
        fft_dir_root: directory containing one sub-directory per model.
        common_fn: file name expected inside every model sub-directory; it is
            loaded with `joblib.load`.

    Returns:
        Dict mapping each sub-directory name to the loaded object, sorted by key.
    """
    d_ffts = {}
    for model_dir in fft_dir_root.iterdir():
        # Skip plain files and hidden entries. The hidden check has to look at
        # the entry's own name: `str(model_dir)` is the whole path, which starts
        # with '.' only for some relative roots (e.g. '../data' would skip all).
        if model_dir.is_file() or model_dir.name.startswith('.'):
            continue

        # NOTE(review): joblib.load unpickles -- only use on trusted files.
        d_ffts[model_dir.name] = joblib.load(model_dir / common_fn)

    return dict(sorted(d_ffts.items()))

def test_load_ffts():
    """Call `load_ffts` on the all-pass avg-fft run directory and return its result."""
    return load_ffts(
        Path('/docker/data/GM256_avgfft_allpass/20230326-004747'),
        "avg-fft-allpass_n=50000_20230326171210.pkl",
    )
        
                    

In [None]:
from cytoolz import valmap, keymap

In [None]:
# Helpers: compute abs-diff
def compute_abs_diff(arr1: np.ndarray, arr2: np.ndarray) -> np.ndarray:
    """Element-wise absolute difference between `arr1` and `arr2`."""
    # (an l2 distance, np.sum(abs_diff ** 2), was considered as a second return value)
    return np.absolute(arr1 - arr2)

def compute_dict_abs_diff_vals(d: Dict[str,np.ndarray], 
                               key_of_ref_val: str) -> Dict[str, np.ndarray]:
    """Compute the abs-diff of every value in `d` against `d[key_of_ref_val]`.

    The dict is first sorted by key (and its keys are printed); the returned
    dict has the same keys, in that sorted order, mapped to
    `compute_abs_diff(value, reference)`.
    """
    ordered = {key: d[key] for key in sorted(d)}
    print(ordered.keys())
    reference = ordered[key_of_ref_val]
    return {
        key: compute_abs_diff(val, arr2=reference)
        for key, val in ordered.items()
    }

In [None]:
# Helpers: plot spectrums in logscale
from reprlearn.visualize.utils import show_dict_with_colorbar

def get_positive_min_max(d: Dict) -> Tuple[float, float]:
    """Smallest and largest strictly positive entry over all values of `d`.

    All values are stacked into a single array; entries <= 0 are masked and
    therefore ignored by the min/max.
    """
    positives = np.ma.masked_less_equal(np.array(list(d.values())), 0, copy=False)
    smallest = positives.min()
    largest = positives.max()
    return smallest, largest
                    

In [None]:
def plot_dict_spectra_logscale(dict_to_show, cmap=None, **kwargs):
    """Plot every array of `dict_to_show` with one shared log-scale color mapping.

    Args:
        dict_to_show: mapping whose values are arrays. The (min, max) returned by
            `get_positive_min_max` over all of them become the LogNorm limits.
        cmap: passed through to `show_dict_with_colorbar`.
        **kwargs: forwarded unchanged to `show_dict_with_colorbar`, so callers
            can pass extra plotting options (without this, any extra keyword
            argument raised a TypeError).
    """
    from matplotlib import colors
    vmin, vmax = get_positive_min_max(dict_to_show)
    normalizer = colors.LogNorm(
        vmin=vmin,
        vmax=vmax,
        clip=False #shouldn't matter whether set to t/f
    )
    print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
    print('vmin, vmax are set? :', normalizer.scaled())

    show_dict_with_colorbar(dict_to_show, normalizer=normalizer, cmap=cmap,
                            **kwargs);

### Load precomputed avg-ffts into a dict

In [None]:
# Locations of the precomputed avg-fft results for each filter type.
filter_type = 'highpass' # 'allpass' 'lowpass'
# NOTE(review): ks, n_samples and run_id are only referenced by the
# commented-out file-name templates below in this cell.
ks = 3
n_samples = 50000
run_id = 20230326014156
fft_dir = Path(f'/docker/data/GM256_avgfft_{filter_type}/20230326-004747')
# fn_dict_avg_ffts_pkl = f"dict-avg-ffts-{filter_type}-ks={ks}-n={n_samples}-{run_id}.pkl"
# fn_dict_avg_ffts_png = f"dict-avg-ffts-{filter_type}-ks={ks}-n={n_samples}-{run_id}.png"

#allpass: one pickle per model sub-directory, gathered by load_ffts
dict_avg_ffts_ap = load_ffts(
    fft_dir_root=Path('/docker/data/GM256_avgfft_allpass/20230326-004747'),
    common_fn = "avg-fft-allpass_n=50000_20230326171210.pkl"
)

#highpass (ks=3): a single pickle loaded directly from fft_dir
# todo:
# NOTE(review): this name is `dict_avg_fft_hp` (no 's'), unlike the
# `dict_avg_ffts_ap` / `dict_avg_ffts_lp` siblings -- easy to mistype downstream.
pkl_fn ='dict-avg-ffts-highpass-ks=3-n=50000-20230326014156.pkl'
dict_avg_fft_hp = joblib.load(fft_dir / pkl_fn)


#lowpass(ks=3 or ks=11)
# fn = todo
# dict_avg_ffts = joblib.load(fft_dir / fn)
# or:
dict_avg_ffts_lp = load_ffts(
    fft_dir_root=Path('/docker/data/GM256_avgfft_lowpass/20230326-004747'),
    common_fn='avg-fft-lowpass_ks=3_n=50000_20230326171058.pkl',
)

    
    


- compute abs-diff-ffts

In [None]:
# Abs-diff of every model's avg-fft against the real data's avg-fft,
# computed separately per filter type.

#allpass
dict_abs_diff_ffts_ap = compute_dict_abs_diff_vals(dict_avg_ffts_ap, 
                                                key_of_ref_val="real-celebahq256")
#todo: write to file


#highpass
# Use the high-pass dict loaded above (`dict_avg_fft_hp`); this previously
# passed `dict_avg_ffts`, i.e. the all-pass spectra from the first section.
dict_abs_diff_ffts_hp = compute_dict_abs_diff_vals(dict_avg_fft_hp, 
                                                key_of_ref_val="real-celebahq256")
#todo: write to file

#lowpass
dict_abs_diff_ffts_lp = compute_dict_abs_diff_vals(dict_avg_ffts_lp, 
                                                key_of_ref_val="real-celebahq256")
#todo: write to file

- plot the abs-diff spectra in logscale

In [None]:
dicts = {
    'allpass': dict_abs_diff_ffts_ap,
#     'highpass': dict_abs_diff_ffts_hp,
    'lowpass': dict_abs_diff_ffts_lp
}
    
for filter_type, dict_to_show in dicts.items():
    plot_dict_spectra_logscale(dict_to_show, title=f"abs-diff: {filter_type}")

In [None]:
def to_fullname(model_name:str) -> str:
    """Map a short model name to its family-prefixed full name.

    e.g., stylegan2 -> gan-stylegan2
          alae -> vae-alae

    Names without an entry in the table (including names that are already
    full names) are returned unchanged.
    """
    full_names = dict(
        ddgan='gan-ddgan',
        stylegan2='gan-stylegan2',
        styleswin='gan-styleswin',
        vqgan='gan-vqgan',
        celebahq256='real-celebahq256',
        ddpm='score-ddpm',
        ldm='score-ldm',
        lsgm='score-lsgm',
        ncsnpp='score-ncsnpp',
        rve='score-rve',
        alae='vae-alae',
        effvdvae='vae-effvdvae',
        nvae='vae-nvae',
        vaebm='vae-vaebm',
    )
    try:
        return full_names[model_name]
    except KeyError:
        return model_name
               
               

In [None]:
dict_abs_diff_ffts = keymap(to_fullname, dict_abs_diff_ffts)
# sort by keynames
dict_abs_diff_ffts = dict(sorted( dict_abs_diff_ffts.items() ))
print('keys renamed: ', dict_abs_diff_ffts.keys())

In [None]:
plot_dict_spectra_logscale(dict_abs_diff_ffts)

##### Resources:
- https://stackoverflow.com/questions/73510185/how-to-add-colorbar-in-matplotlib
- https://stackoverflow.com/questions/13784201/how-to-have-one-colorbar-for-all-subplots


In [None]:
from matplotlib import cm, colors

In [None]:
# Demo1: Normalize (linear) 
# At init, vmin and vmax are not set yet:
normalizer = colors.Normalize(
    vmin=None,
    vmax=None,
    clip=False
)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
print('vmin, vmax are set? :', normalizer.scaled())

# at first normalization call using this class, its vmin and vmax will be set
# to the min and max of the input arr __call__ was called:
arr = [-2., -1., 0., 1., 2.]
normed_arr = normalizer(arr)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
print('vmin, vmax are set? :', normalizer.scaled())
assert normalizer.vmin == min(arr) and normalizer.vmax == max(arr)

print('arr: ', arr)
print('normed_arr: ', normed_arr)

In [None]:
# Demo2: Normalize (linear) 
# At init, vmin and vmax are specified by user.
# w/ clip=False
normalizer = colors.Normalize(
    vmin=-1,
    vmax=1,
    clip=False
)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
print('vmin, vmax are set? :', normalizer.scaled())

# vmin and vmax were already given at init (unlike Demo1, where they start as
# None), so they are printed again after the call to check they are unchanged.
# The input deliberately contains values outside [vmin, vmax] = [-1, 1]:
arr = [-2., -1., 0., 1., 2.]
normed_arr = normalizer(arr)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
print('vmin, vmax are set? :', normalizer.scaled())

print('arr: ', arr)
print('normed_arr: ', normed_arr)
print("Note that if clip=False, then the normalizer does not enforce clipping on the normed values "
      "to [0.,1.]: ie., if normed value is out of range [0,1], it just gives out those values. ")

In [None]:
# Demo3: Normalize (linear) 
# At init, vmin and vmax are specified by user.
# w/ clip=True
normalizer = colors.Normalize(
    vmin=-1,
    vmax=1,
    clip=True
)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
print('vmin, vmax are set? :', normalizer.scaled())

# Same setup as Demo2 (vmin/vmax given at init; input contains values outside
# [-1, 1]), but with clip=True -- compare the printed normed_arr with Demo2's:
arr = [-2., -1., 0., 1., 2.]
normed_arr = normalizer(arr)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
print('vmin, vmax are set? :', normalizer.scaled())

print('arr: ', arr)
print('normed_arr: ', normed_arr)

In [None]:
# Demo4-a: LogNorm
# At init, vmin and vmax are not specified.
# w/ clip=False
normalizer = colors.LogNorm(
    vmin=None,
    vmax=None,
    clip=False
)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
print('vmin, vmax are set? :', normalizer.scaled())

# at first normalization call using this class, its vmin and vmax will be set
# to the min and max of the input arr __call__ was called:
arr = [-2., -1., 0., 1., 2.]
normed_arr = normalizer(arr)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax) #0,2?
print('vmin, vmax are set? :', normalizer.scaled())

print()
print('arr: ', arr)
print('normed_arr: ', normed_arr)

In [None]:
# Demo4-b: LogNorm
# At init, vmin and vmax are not specified.
# w/ clip=False
normalizer = colors.LogNorm(
    vmin=None,
    vmax=None,
    clip=False
)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
print('vmin, vmax are set? :', normalizer.scaled())
print('---')

# arr = [-10., -1., 0., 1, 10., 100., 1000]
arr = [-10., -1., 0., 0.001, 1, 10., 100., 1000]
# arr = [-10., -1., 0., 1, 10., 100., 1000]
normed_arr = normalizer(arr)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax) #0,1000?
print('vmin, vmax are set? :', normalizer.scaled())
print('---')

print('arr: ', arr)
print('normed_arr: ', normed_arr)

In [None]:
### Demo 5: LogNorm
# At init, vmin and vmax are not specified.
# w/ clip=False
# Note: vmin and vmax properties of a Normalize object is set either at init time
#      or at the first __call__ time. 
# In particular, any subsequent __call__ (with diff. data array to normalize)
#     does not change the vmin and vmax values.

normalizer = colors.LogNorm(
    vmin=None,
    vmax=None,
    clip=False
)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax)
print('vmin, vmax are set? :', normalizer.scaled())

# at first normalization call using this class, its vmin and vmax will be set
# to the min and max of the input arr __call__ was called:
arr = [-2., -1., 0., 1., 2.]
normed_arr = normalizer(arr)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax) #0,2?
print('vmin, vmax are set? :', normalizer.scaled())

print('arr: ', arr)
print('normed_arr: ', normed_arr)


print('===')
print("Note: Normalize object's vmin and vmax are set either at init time or at the first call. "
      "-- ie, any subsequent __call__ does not change vmin and vmax!! ")

arr = [-10., -1., 0., 10., 100., 1000]
normed_arr = normalizer(arr)
print('vmin, vmax: ', normalizer.vmin, normalizer.vmax) #0,1000?
print('vmin, vmax are set? :', normalizer.scaled())

print('arr: ', arr)
print('normed_arr: ', normed_arr)


In [None]:
arr
masked = np.ma.masked_less_equal(arr, value=0, copy=True)
print(arr,masked)

In [None]:
min(arr), min(masked), masked.min()

### Use LogNorm to set one consistent mapping from data values to colors across all npimgs


In [None]:
# show_dict_of_arrs(dict_abs_diff_ffts,
#                   transform=lambda fft: to_logscale(
#                                           set_zero_around_center(fft, window_size=3)
#                                         )
#                  );

In [None]:
# same key:value mapping, but sorted by key value's alphabetical order
sorted_dict = dict(sorted(dict_abs_diff_ffts.items()))
# -- verify
show_dict_of_arrs(sorted_dict,
                  transform=lambda fft: to_logscale(
                                          set_zero_around_center(fft, window_size=3)
                                        )
                 );



In [None]:
# sort dictionary by its key values' alphabetic order
dict_abs_diff_ffts = dict(sorted(dict_abs_diff_ffts.items()))


In [None]:
# Stack the dict into two parallel arrays: model_names[i] is the key of arrs[i].
model_names = np.array(list(dict_abs_diff_ffts.keys()))
arrs = np.array(list(dict_abs_diff_ffts.values()))

In [None]:
len(arrs), len(model_names)

In [None]:
vmin, vmax = arrs.min(), arrs.max()
print('vmin, vmax of abs-diff ffts: ', vmin, vmax)

In [None]:
import seaborn as sns

In [None]:
sns.histplot(arrs.flatten(), bins=50)


In [None]:
# Scatter every abs-diff value (all models pooled) along the horizontal line y=10.
# The previous version had a `for` loop with no body (a syntax error) and passed
# a scalar `y`; here `y` is an array with one entry per point in `x`.
x = arrs.flatten()
y = np.full(x.shape, 10)
sns.scatterplot(x=x, y=y);


In [None]:
df_abs_diff_fft = pd.DataFrame(columns=['fam_name', 'model_name', 'abs_diff_fft'])

In [None]:
df_abs_diff_fft['model_name'] = list(dict_abs_diff_ffts.keys())
df_abs_diff_fft['abs_diff_fft'] = list(dict_abs_diff_ffts.values())
df_abs_diff_fft.head()

In [None]:
df_abs_diff_fft['fam_name'] = df_abs_diff_fft['model_name'].apply(lambda x: x.split('-')[0])

In [None]:
df_abs_diff_fft.head()


In [None]:
n_models = len(dict_abs_diff_ffts.keys())
# One RGBA color per model, sampled evenly from the rainbow colormap.
# NOTE(review): this rebinds `colors`, the name the earlier
# `from matplotlib import cm, colors` cell bound to the matplotlib module.
colors = cm.rainbow(np.linspace(0,1,n_models))

fig, ax  = plt.subplots()
for arr, c in zip(dict_abs_diff_ffts.values(), colors):
    y = [10]*arr.size #number of elements in arr 
    # `color=c` already fixes the color of every point of this model. `cmap`
    # expects a Colormap (or its name), not the RGBA array, so it is not passed.
    ax.scatter(x=arr.ravel(), y=y, color=c)

# to visualize the colormap:
# Indices to step through colormap
x = np.linspace(0.0, 1.0, 100)
fig_color, ax_color = plt.subplots()
