# Deep, non-linear, two-factor model
A standard symmetric bilinear model in Tenenbaum2000 (Tenenbaum & Freeman, 2000) can be described as:
$$y^{sc}_k = \sum_{j} \sum_{i} w_{ijk}a^s_{i}b^c_{j}$$, which has an equivalent vector form:

$$\mathbf{y}^{sc} = \sum_{j} \sum_{i} \mathbf{w}_{ij}a^s_{i}b^c_{j}$$ where $\mathbf{w}_{ij}$ is a vector of length $K$ whose $k$-th entry is $w_{ijk}$.

This symmetric model has 3 types of model parameters:
- style vectors $a^s$ of length $I$
- content vectors $b^c$ of length $J$
- the interaction weights $W$: $K$ matrices $W_k$ of size $(I,J)$ (entries $w_{ijk}$). In total, this 3-dim tensor $W$ has $I \times J \times K$ parameters.
    - Basis vector interpretation (See Eqn. 2.3): An alternative way to view the interaction weights $W$ is as $I \times J$ vectors $w_{ij}$, each of length $K$.
      The vector $w_{ij}$ specifies how the $i$th component of a style vector $a^s$ and the $j$th component of a content vector $b^c$ interact over the entire image/data point.

## Load libraries

In [None]:
# Reload edited modules automatically before executing code
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook
%matplotlib inline
# %reset out

In [None]:
import os,sys
import re
import math
from datetime import datetime
import time
sys.dont_write_bytecode = True  # don't write .pyc files when importing modules

In [None]:
import pandas as pd
import joblib

import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
from skimage.color import rgb2gray
from skimage.transform import resize

from pprint import pprint
from pathlib import Path
from typing import List, Set, Dict, Tuple, Optional, Iterable, Mapping, Union, Callable

from ipdb import set_trace

In [None]:
# import holoviews as hv
# from holoviews import opts
# hv.extension('bokeh')

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from  torch.linalg import norm as tnorm

from torchvision import datasets, transforms
from torch.autograd import Variable

## Helpers

In [None]:
def now2str():
    """Return the current local time as a "MM_DD_HH:MM:SS" string (e.g. "03_07_14:05:09")."""
    return datetime.now().strftime("%m_%d_%H:%M:%S")

def info(arr, header=None):
    """
    Print a short summary of an array-like: a header line, shape, dtype, min and max.

    Args:
    - arr: numpy array or torch tensor that np.asarray can convert (e.g. a CPU tensor
      that does not require grad); must also expose `.shape` and `.dtype`.
    - header: line printed first; defaults to a row of 30 "=".
    """
    if header is None:
        header = "="*30
    print(header)
    print("shape: ", arr.shape)
    print("dtype: ", arr.dtype)
    # Vectorized reductions instead of Python-level min()/max() over every element
    values = np.asarray(arr)
    print("min, max: ", values.min(), values.max())

In [None]:
from sklearn.preprocessing import minmax_scale
def normalize_img(img, vmin=None, vmax=None, *, use_global_range=False):
    """
    Linearly rescale `img` so that `vmin` maps to 0 and `vmax` maps to 1.

    Args:
    - img: numpy array or torch tensor of any shape.
    - vmin, vmax: values mapped to 0 and 1; default to img.min() / img.max().
      Values outside [vmin, vmax] land outside [0, 1] (no clipping).
    - use_global_range: currently unused; kept for backward compatibility.

    Returns: numpy array with the same shape as `img`. Float inputs keep their
    float dtype and integer inputs become float64. A constant image (vmax == vmin)
    maps to all zeros, the same convention as sklearn's minmax_scale.
    """
    if hasattr(img, "detach"):
        # torch tensor -> numpy; detach so tensors that still require grad also work
        img = img.detach().cpu().numpy()
    arr = np.asarray(img)
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)
    lo = arr.min() if vmin is None else vmin
    hi = arr.max() if vmax is None else vmax
    scale = hi - lo
    if scale == 0:
        # Avoid division by zero for constant images
        scale = 1
    return (arr - lo) / scale
      

In [None]:
def visualize(data: torch.Tensor, 
              n_styles:int, n_contents:int, img_size:Tuple[int, ...],
              *,title:str=None, normalize:bool=False) -> plt.Figure:
    """
    Plot a grid of images with one row per style and one column per content.

    Args:

    - data (torch.Tensor or numpy array): the images to plot.
    If data.ndim == 2, data.shape is assumed to be (n_styles*dim_x, n_contents),
    i.e. rows [s*dim_x, (s+1)*dim_x) of column c hold the image of style s, content c.
    If data.ndim == 3, data.shape is assumed to be (n_styles, n_contents, dim_x).
    Should not require grad (detach it or call under torch.no_grad()).
    - n_styles, n_contents: number of grid rows / columns to draw.
    - img_size: shape each flattened image is reshaped to; dim_x == prod(img_size).
    - title: optional figure title.
    - normalize (bool): rescale each image independently so its min and max map to 0 and 1
       - Potentially useful for visualization of gradients or eigenbasis

    Returns: the matplotlib Figure.

    Raises: ValueError if data.ndim is not 2 or 3.
    """
    if data.ndim not in (2, 3):
        raise ValueError(f"data must be 2- or 3-dimensional, got ndim={data.ndim}")
    dim_x = int(np.prod(img_size))

    # squeeze=False keeps `ax` 2-D even when n_styles or n_contents is 1
    f, ax = plt.subplots(n_styles, n_contents, figsize=(5*n_contents, 5*n_styles),
                         squeeze=False)
    if title is not None:
        f.suptitle(title)
    for s in range(n_styles):
        for c in range(n_contents):
            if data.ndim == 2:
                img = data[s*dim_x:(s+1)*dim_x, c].reshape(img_size)
            else:
                img = data[s, c].reshape(img_size)
            if normalize:
                img = normalize_img(img)
            ax[s][c].imshow(img)
    # Lay out only after the axes have content
    f.tight_layout()
    return f

In [None]:
def visualize_vectors(A: torch.Tensor, is_column: bool=True, title:str=None) -> plt.Figure:
    """
    Visualize each vector in the input tensor as a bar chart (one subplot per vector).

    - A: 2dim tensor/array of vectors; assumed to be detached.
        - A 1dim input is treated as a single vector.
    - is_column (bool): if True, assume A to be a collection of column vectors.
        - Otherwise, A is assumed to be a collection of row vectors
    - title: optional figure title

    Returns: the matplotlib Figure.
    """
    if A.ndim == 1:
        A = A[:, None]  # a single vector -> one column
    elif not is_column:
        A = A.T
    n_vecs = A.shape[1]

    # squeeze=False so `ax` is always indexable, even for a single vector
    f, ax = plt.subplots(nrows=1, ncols=n_vecs, figsize=(n_vecs*3, 3), squeeze=False)
    if title is not None:
        f.suptitle(title)
    for i, axis in enumerate(ax[0]):
        vec = A[:, i]
        axis.bar(range(len(vec)), vec)
        axis.set_title(f'Vector {i+1}')
    # Lay out only after the axes have content
    f.tight_layout()
    return f


In [None]:
# Training set specific
n_styles = 3
n_contents = 9


# Hyperparams
img_size = (64,64,3) #(img_h, img_w, n_channels)
dim_content = 4 # J
dim_style = 3 # I
dim_x = np.prod(img_size) # K: number of values in one flattened image

# Initialize model parameters
# contents = torch.randn((n_contents, 1, dim_content)) # B: each row is a content vector
# styles = torch.randn((n_styles, 1, dim_style)) # A: each row is a style vector
# W = torch.randn((dim_x, dim_style, dim_content))

# Version2: vectorized implementation for multiple content vectors and multiple style vectors
# -- See `00_matmul_broadcasting.ipynb` for details on how to set the shape of the tensors below
# -- for correct vectorized implementation of "generative" process 
# A (a tensor of all style vectors): (S,1,1,I)
# W (a tensor of all bilinear weights invariant to content, style classes): (K,I,J)
# B (a tensor of all content vectors): (C, 1,1, J,1)

# out = A.matmul(W) # (S x (K,1,J))
# out2 = out.matmul(B) # (C x   (S x (K,1,1))  )

styles = torch.randn((n_styles, 1,1, dim_style)) # A: (S,1,1,I); styles[s,0] is a (1,I) row vector
W = torch.randn((dim_x, dim_style, dim_content)) # W: (K,I,J)
contents = torch.randn((n_contents, 1,1, dim_content,1)) # B: (C,1,1,J,1); contents[c,0,0] is a (J,1) column vector

class TFModel(nn.Module):
    """
    Two-factor (style x content) model implemented with PyTorch.

    For every style s and content c it computes
        y^{sc}_k = non_linear( sum_{i,j} W[k,i,j] * a^s_i * b^c_j )
    with two broadcast matmuls (see the shape comments in `forward`).

    Args:
    - styles: style vectors, shape (S, 1, 1, I)
    - contents: content vectors, shape (C, 1, 1, J, 1)
    - W: bilinear interaction weights, shape (K, I, J)
    - n_layers: number of bilinear layers; only a single layer is implemented and the
      value is only used in `shortname()`
    - non_linear: element-wise function applied to the output; defaults to a fresh
      nn.Identity() per model
    """
    def __init__(self, styles, contents, W, 
                 n_layers=1, non_linear: Optional[Callable]=None):
        super().__init__()
        self.n_layers = n_layers
        # Build the default here instead of sharing one default-argument instance
        self.non_linear = nn.Identity() if non_linear is None else non_linear

        self.styles = nn.Parameter(styles) #(S,1,1,I)
        self.contents = nn.Parameter(contents) #(C,1,1,J,1)
        self.W = nn.Parameter(W) #(K,I,J)

        self.n_styles, _, _, self.I = styles.shape
        self.n_contents, _, _, self.J, _ = contents.shape
        self.K = W.shape[0]

        # Parameter snapshots taken by `cache_params` (name -> detached clone)
        self.cache = {}

    def forward(self, s=None, c=None):
        """
        Reconstruct images for the selected style(s) and content(s).

        s: style label; must be in {0,...,n_styles-1}; None selects all styles
        c: content label; must be in {0,..., n_contents-1}; None selects all contents
        May be passed positionally or by keyword.

        Returns: tensor of shape (S', C', K), where S' / C' is 1 when a label is
        given and n_styles / n_contents otherwise.
        """
        # Indexing with a list keeps the leading (batch) dimension
        A = self.styles if s is None else self.styles[[s]]
        B = self.contents if c is None else self.contents[[c]]
        out = A.matmul(self.W)  # (S,1,1,I) @ (K,I,J) -> (S,K,1,J)
        out = out.matmul(B)     # (S,K,1,J) @ (C,1,1,J,1) -> (C,S,K,1,1)

        # By convention, output tensor has size of (S,C,K). Only the two trailing
        # singleton dims are squeezed so that S, C or K equal to 1 is preserved.
        out = out.permute(1, 0, 2, 3, 4).squeeze(-1).squeeze(-1)

        # Element-wise nonlinearity (Identity by default; e.g. nn.Sigmoid() if the
        # targets are scaled to [0,1])
        # todo: more layers
        return self.non_linear(out)

    def shortname(self):
        return f"bilinearx{self.n_layers}_{self.non_linear}"

    def descr(self):
        """Descriptive name including all dimensions (S, I, C, J, K) of this model."""
        return (f"{self.shortname()}_S:{self.n_styles}_I:{self.I}"
                f"_C:{self.n_contents}_J:{self.J}_K:{self.K}")

    def cache_params(self):
        """Store a detached copy of every parameter in `self.cache`."""
        with torch.no_grad():
            for name, param in self.named_parameters():
                self.cache[name] = param.detach().clone()

    def some_params_not_changed(self) -> bool:
        """
        Return True if any parameter is identical to its snapshot from `cache_params`.

        Prints the unchanged parameters (with their norms) when there are any.
        Raises KeyError if a parameter has not been cached yet.
        """
        with torch.no_grad():
            not_changed = {
                name: tnorm(param)
                for name, param in self.named_parameters()
                if torch.equal(self.cache[name], param)
            }
        if not not_changed:
            return False
        print("Not changed: \n", not_changed)
        return True

    def all_params_changed(self) -> bool:
        return not self.some_params_not_changed()

    def show_params(self, image_size: Optional[Tuple[int, ...]] = None):
        """
        Plot the current styles, contents and W (one image per (i, j) basis vector).

        image_size: shape used to reshape each length-K vector of W; defaults to the
            module-level `img_size`.
        Returns: dict mapping "styles", "contents", "W" to their matplotlib Figures.
        """
        size = img_size if image_size is None else image_size
        with torch.no_grad():
            return {
                "styles": visualize_vectors(self.styles.squeeze(), is_column=False, title='Styles'),
                "contents": visualize_vectors(self.contents.squeeze(), is_column=False, title='Contents'),
                "W": visualize(self.W.permute(1,2,0), self.I, self.J, size, title='W', normalize=True),
            }

    def show_grads(self, image_size: Optional[Tuple[int, ...]] = None):
        """
        Plot the current gradients of styles, contents and W (same layout as show_params).

        image_size: shape used to reshape each length-K vector of W.grad; defaults to
            the module-level `img_size`.
        Returns: dict mapping "styles", "contents", "W" to their matplotlib Figures.
        Raises: RuntimeError if a gradient is missing (backward() not called yet).
        """
        if any(p.grad is None for p in (self.styles, self.contents, self.W)):
            raise RuntimeError("Gradients are not populated; call loss.backward() first.")
        size = img_size if image_size is None else image_size
        with torch.no_grad():
            return {
                "styles": visualize_vectors(self.styles.grad.squeeze(), is_column=False, title='Styles'),
                "contents": visualize_vectors(self.contents.grad.squeeze(), is_column=False, title='Contents'),
                "W": visualize(self.W.grad.permute(1,2,0), self.I, self.J, size, title='W', normalize=True),
            }




In [None]:
def to_3dim(X: torch.Tensor, target_size: Tuple[int,int,int], dtype=torch.float32)->torch.Tensor:
    """
    Rearrange data matrix X of size (n_styles*dim_x, n_contents) 
    to (n_styles, n_contents, dim_x)

    Rows [s*dim_x, (s+1)*dim_x) of column c hold the image of style s and content c,
    so the result satisfies out[s, c] == X[s*dim_x:(s+1)*dim_x, c].

    Args: 
    - X: 2dim data matrix (torch.Tensor, or anything torch.as_tensor accepts, e.g. numpy)
    - target_size: tuple of n_styles, n_contents, dim_x
    - dtype: dtype of the returned tensor

    Returns: a new contiguous tensor (never a view of X) of shape target_size.

    Raises: ValueError if X is not 2dim or its shape does not match target_size.
    """
    X = torch.as_tensor(X)
    n_styles, n_contents, dim_x = (int(v) for v in target_size)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-dimensional, got ndim={X.ndim}")
    if tuple(X.shape) != (n_styles * dim_x, n_contents):
        raise ValueError(
            f"X.shape={tuple(X.shape)} does not match "
            f"(n_styles*dim_x, n_contents)={(n_styles * dim_x, n_contents)}"
        )

    # (S*K, C) -> (S, K, C) -> (S, C, K): replaces the per-(s, c) copy loop
    out = X.reshape(n_styles, dim_x, n_contents).permute(0, 2, 1)
    # copy=True + contiguous_format: always return fresh row-major memory
    return out.to(dtype=dtype, memory_format=torch.contiguous_format, copy=True)
    
        
def mse(out, target, *, per_sample=False):
    """
    Return the mean squared reconstruction error between `out` and `target`.

    out: a minibatch of reconstructed images: (S,C,K)
    target: a minibatch of ground-truth images: (S,C,K)
    per_sample (bool): if False (default), average the squared error over every
        element (same as nn.MSELoss()(out, target)). If True, sum the squared error
        over pixels and average over the S*C images.

    Returns: 0-dim tensor.

    Raises: ValueError if the shapes differ.
    """
    if out.shape != target.shape:
        raise ValueError(
            f"shape mismatch: out {tuple(out.shape)} vs target {tuple(target.shape)}"
        )
    sq_err = (out - target).pow(2)
    if per_sample:
        n_styles, n_contents, _ = out.shape
        return sq_err.sum() / (n_styles * n_contents)
    return sq_err.mean()

# Default (mean-reduced) MSE loss; the training loop below builds its own nn.MSELoss with an explicit `reduction`
loss_fn = nn.MSELoss()   
    

In [None]:
# Inspect the initial parameter shapes: styles (S,1,1,I), contents (C,1,1,J,1)
styles.shape,contents.shape


In [None]:
# Sanity check: contents end in a (J,1) column vector and styles in a (1,I) row vector, as the matmuls in TFModel.forward require
contents.shape[-2:] == (dim_content,1), styles.shape[-2:] == (1,dim_style)

In [None]:
# Wrap the initial tensors as learnable parameters of a two-factor model
model = TFModel(styles=styles, contents=contents, W=W)

In [None]:
# List every learnable parameter together with its shape
for param_name, param in model.named_parameters():
    print(f"{param_name}: {param.shape}")

In [None]:
# Reconstruction for style 0 / content 0; forward() takes the labels as keyword arguments
model(s=0, c=0).shape

In [None]:
# Reconstruct every (style, content) pair at once; by TFModel's convention the output is (S, C, K)
out = model()
out.shape

## Restore data matrix variable X as saved from the notebook "02"


In [None]:
# Restore variables saved with `%store` in notebook "02".
# Presumably X is the data matrix of shape (n_styles*dim_x, n_contents) and TARGET_SIZE the image size (h, w, n_channels) -- see the next cell
%store -r X
%store -r TARGET_SIZE

In [None]:
# Test create_target
def test_create_target():
    """Placeholder for a test of `to_3dim`; not implemented yet."""
    pass

# Infer the data dimensions from X of shape (n_styles*dim_x, n_contents)
# 3 styles, 9 contents, x_dim = np.prod(TARGET_SIZE), TARGET_SIZE = (64,64,3) 
sx, n_contents = X.shape
dim_x = np.prod(TARGET_SIZE)
img_size = TARGET_SIZE
# Check divisibility instead of silently truncating with int(sx/dim_x)
if sx % dim_x != 0:
    raise ValueError(f"X has {sx} rows, which is not a multiple of dim_x={dim_x}")
n_styles = int(sx // dim_x)
print(X.shape)
print("n_styles, n_contents, dim_x: ", n_styles, n_contents, dim_x)

In [None]:
# Rearrange X from (n_styles*dim_x, n_contents) into (n_styles, n_contents, dim_x)
X_3d = to_3dim(X, (n_styles, n_contents, dim_x) )
X_3d.shape

In [None]:
# Show the reconstructions of the freshly initialized (untrained) model; each image is min-max normalized for display
# visualize(X, n_styles, n_contents, img_size);
# visualize(X_3d, n_styles, n_contents, img_size);
visualize(out.detach(), n_styles, n_contents, img_size, 
          normalize=True);

## Compiled training specs

In [None]:
def mkdir(p: Path, parents=True):
    """
    Create directory `p` (and, if `parents`, any missing ancestors) unless it already exists.

    Accepts a str or Path. Prints the path when a directory was actually created.

    Raises: FileExistsError if `p` exists but is not a directory.
    """
    p = Path(p)
    try:
        # Create directly rather than exists()-then-mkdir, which races with
        # other processes creating the same directory
        p.mkdir(parents=parents)
    except FileExistsError:
        if not p.is_dir():
            raise
        return
    print("Created: ", p)


In [None]:
def create_exp_name(hyperparams):
    """Placeholder for building an experiment name from hyperparameters; not implemented (returns None)."""
    pass

# Hyperparameters
# NOTE(review): these must match the data matrix X restored above (used to build `target` below) -- confirm
n_styles, dim_style = 3, 3
n_contents, dim_content = 9, 4
img_size = (64,64,3)
dim_x = np.prod(img_size)

# Define model
styles = torch.randn((n_styles, 1,1, dim_style)) # A: each row is a style vector
W = torch.randn((dim_x, dim_style, dim_content))
contents = torch.randn((n_contents, 1,1, dim_content,1)) # B: each column is a content vector

model = TFModel(styles, contents, W)
# model.show_params()

# Gradient computation
## learn_rate depending on the type of reduction on computing the MSELoss
## ('sum' makes the loss, and hence the gradients, S*C*K times larger than 'mean')
lrs = {'mean': 1e-2,
      'sum': 1e-6}


# Specify loss function and learning rate
reduction = 'mean'
lr = lrs[reduction]
lr_W = lr*30 # W uses a 30x larger learning rate than the style/content vectors
# Optimizer
# Two parameter groups: styles/contents use the default `lr`, W overrides it with `lr_W`
optim_params = [
    {'params': [model.styles, model.contents]},
    {'params': [model.W], 'lr': lr_W}
]
optimizer = optim.Adam(optim_params, lr=lr)


# Training configs
max_epoches = 100
print_every = 10
show_every = 30

# data
# target: (n_styles, n_contents, dim_x), the same layout as the model output
target = to_3dim(X, (n_styles, n_contents, dim_x))

# Start training
# Full-batch training: every epoch reconstructs all (style, content) pairs at once
start = time.time()
losses = []
for ep in range(max_epoches):
    # Compute loss, and compute partial derivatives wrt each parameters, which will be stored 
    # in each parameter (tensor)'s `.grad` property
    out = model()
    loss = nn.MSELoss(reduction=reduction)(out, target) # 'mean': averaged over every pixel of every image; 'sum': summed
    
    # Make sure all the `.grad`s of the model parameters are zero 
    optimizer.zero_grad()
    loss.backward()
    losses.append(loss.item())

    
    # Check if the parameters are changing before/after the gradient step
    model.cache_params()
    # Update the parameter values using the current partial derivatives based on the current loss
    optimizer.step()
    # Prints the parameters left unchanged by the step, if any; the returned bool is not used
    model.all_params_changed()
   
#     set_trace()
    
    # Log
    with torch.no_grad():
        if (ep+1)%print_every == 0:
            # `ep` is 0-based: printed after epochs 10, 20, ... but shows 9, 19, ...
            print(f"Ep {ep}: {loss.item()}")
            for n,p in model.named_parameters():
                print(n)
                print('\t', tnorm(p), tnorm(p.grad))
        if (ep+1)%show_every == 0:
            model.show_params()
            # `out` is the reconstruction computed before this epoch's optimizer step
            visualize(out, n_styles, n_contents, img_size, normalize=True);
print(f"Took {time.time() - start} sec. Loss: {losses[-1]}")

In [None]:
# Experiment name
# f-string so that the model description is actually interpolated into the path
result_dir = Path(f"../results/batch_bilinear/{model.descr()}")
mkdir(result_dir)
exp_descr = f"reduction:{reduction}_lr:{lr}_lrW:{lr_W}_ep:{ep}"

# Save figures of the learned parameters and of the last reconstructions.
# An explicit ".png" suffix is appended because exp_descr contains "." (e.g. lr:0.01),
# which matplotlib would otherwise interpret as the file-format extension.
with torch.no_grad():
    # Build the parameter figures here (same plots as model.show_params()) so that
    # each figure object is available for saving
    param_figs = {
        "styles": visualize_vectors(model.styles.squeeze(), is_column=False, title='Styles'),
        "contents": visualize_vectors(model.contents.squeeze(), is_column=False, title='Contents'),
        "W": visualize(model.W.permute(1,2,0), model.I, model.J, img_size, title='W', normalize=True),
    }
    for fig_name, fig in param_figs.items():
        fig.savefig(result_dir/f"params_{fig_name}_{exp_descr}.png")

    out = model()
    f_out = visualize(out, n_styles, n_contents, img_size, normalize=True)
    f_out.savefig(result_dir/f"xhat_{exp_descr}.png")

In [None]:
# Training loss per epoch
plt.plot(losses)
