# Conditional Real-NVPs

**Goal:** Extend the Real-NVP model we built in `Real-NVP-tutorial` to make it conditional, i.e., a model for $p_\theta(x|y)$.

![](flow-graphic.png)


In [None]:
from sklearn.datasets import make_moons

import matplotlib.pyplot as plt


import math
import numpy as np
import torch
import torch.nn as nn

## First, some preliminaries copied over from the previous notebook

**Dataset**

In [None]:
# Draw the 2-D two-moons point cloud (make_moons also returns labels; they are dropped)
nsamples = 30_000
noise = 0.05
X = make_moons(n_samples=nsamples, noise=noise)[0]

X_torch = torch.as_tensor(X, dtype=torch.float32)

**Hyperparameters and settings from the previous notebook**

In [None]:
# Depth (number of coupling layers) and width of each coupling sub-network
num_blocks, num_hidden = 9, 64

In [None]:
# Run on the first GPU when one is available, otherwise fall back to the CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(device)

In [None]:
# Mini-batch size and the number of examples in the train / validation splits
batch_size = 256
nTrain, nVal = 24_000, 3_000

In [None]:
# Extra DataLoader options, used only when a GPU is available
kwargs = dict(num_workers=4, pin_memory=True) if torch.cuda.is_available() else {}

**Extended functionality to draw the model results with a conditional input**

In [None]:
def draw_model(model, cond_input=None, title='', num_samples=500):
    '''
    Goal: Given the model weights, show
    (1) The density p(x) -- or p(x|y) for a conditional model -- on a grid
    (2) Samples from the model

    Inputs:
    - model: flow exposing `log_probs` and `sample`. Unconditional models are
             called as `log_probs(X)` / `sample(n)`, conditional ones as
             `log_probs(X, Y)` / `sample(Y)`.
    - cond_input: None for an unconditional model; otherwise the conditioning
                  values y (e.g. the moon center [y0, y1]), repeated for every
                  grid point and every sample
    - title: Super title over the 2 subfigs
    - num_samples: number of points drawn for the sample panel (default 500)
    '''

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4),
                                   gridspec_kw={'wspace': .3})

    if title:
        fig.suptitle(title)

    # Conditioning vector as a (1, num_cond_inputs) row, repeated as needed below.
    # `is not None` (rather than truthiness) so array/tensor inputs also work.
    cond = None
    if cond_input is not None:
        cond = torch.as_tensor(cond_input, dtype=torch.float32,
                               device=device).reshape(1, -1)

    # (1) Plot the density
    x = np.linspace(-1, 2)
    y = np.linspace(-.75, 1.25)
    xx, yy = np.meshgrid(x, y)

    X_grid = np.vstack([xx.flatten(), yy.flatten()]).T.astype(np.float32)
    X_grid = torch.tensor(X_grid).to(device)

    with torch.no_grad():
        if cond is None:
            log_probs = model.log_probs(X_grid)
        else:
            log_probs = model.log_probs(X_grid, cond.repeat(len(X_grid), 1))
    log_probs = log_probs.cpu().numpy()

    # Put the flat grid values back onto the meshgrid shape
    ax1.pcolormesh(xx, yy, np.exp(log_probs.reshape(xx.shape)),
                   shading='auto', cmap='coolwarm')

    ax1.set_xlabel('$X_0$', fontsize=12)
    ax1.set_ylabel('$X_1$', fontsize=12)

    # (2) Plot samples from the model
    with torch.no_grad():
        if cond is None:
            X_gen = model.sample(num_samples).cpu().numpy()
        else:
            X_gen = model.sample(cond.repeat(num_samples, 1)).cpu().numpy()

    ax2.scatter(*X_gen.T)

    ax2.set_xlabel('$X_0$', fontsize=12)
    ax2.set_ylabel('$X_1$', fontsize=12)

    # Use the same window as the density panel
    ax2.set_xlim(x[[0, -1]])
    ax2.set_ylim(y[[0, -1]])

    plt.show()


# Bonus: Conditional flow

Lots of applications in science involve conditional flow models. Can you extend the model we built here to be conditioned on the center point of the moons?

**Plan:** 
- Train on moons whose centers $y \in \mathbb{R}^2$ are sampled **uniformly** from $[0,1]^2$

Recall that our dataset has 30k samples, each of which is 2-dimensional.

In [None]:
# Conditioning variable: one uniform [0, 1) offset ("center") per sample and per data dimension
num_inputs = X.shape[1]
Y = torch.empty(nsamples, num_inputs).uniform_()

In [None]:
# Shift each moon sample by its own center
X_cond = torch.add(X_torch, Y)

**Let's look at the training data for some slices of the conditioning variable $y$**
- $0 < y_0, y_1 < .05$ 
- $0.5 < y_0, y_1 < .55$ 
- $0.95 < y_0, y_1 < 1$ 

In [None]:
dy = .05
for i, y_min in enumerate([0, .5, 1 - dy]):
    y_max = y_min + dy

    # Keep samples whose center lies strictly inside (y_min, y_max) in both coordinates
    mi = ((Y > y_min) & (Y < y_max)).all(dim=1)

    # Midpoint of the slice, marked with an 'x'
    y_avg = (y_min + y_max) / 2

    c = f'C{i}'
    plt.scatter(*X_cond[mi].T.numpy(), alpha=.5, color=c)
    plt.scatter([y_avg], [y_avg], 250, marker='x', label='avg center', color=c)

    plt.xlabel('$X_0$', fontsize=18)
    plt.ylabel('$X_1$', fontsize=18)
    plt.legend()
    plt.title(f'$y_0,y_1$ center in ({y_min},{y_max})')

    plt.xlim(-1.2, 3.2)
    plt.ylim(-.7, 2.2)
    plt.show()
    

### TO DO: Implement the conditional flow model

**Hint 1:** Extend the coupling layer to a conditional coupling layer

In [None]:
class CondCouplingLayer(nn.Module):
    '''
    Real-NVP affine coupling layer conditioned on extra inputs y.

    The coordinates of x where `mask == 1` pass through unchanged; the others
    are scaled and shifted by networks that see the masked x together with y:

        forward (x -> u): u = x * exp(s) + t,     log|det J| =  sum(s)
        inverse (u -> x): x = (u - t) * exp(-s),  log|det J| = -sum(s)

    with s = scale_net([x * mask, y]) * (1 - mask)
    and  t = translate_net([x * mask, y]) * (1 - mask).

    Inputs:
    - num_inputs: dimension of x
    - num_cond_inputs: dimension of the conditioning vector y
    - num_hidden: width of the hidden layers of both sub-networks
    - mask: tensor with num_inputs entries in {0, 1}
    '''

    def __init__(self, num_inputs, num_cond_inputs, num_hidden, mask):
        super(CondCouplingLayer, self).__init__()

        self.num_inputs = num_inputs
        # Registered as a buffer so `.to(device)` on the flow also moves the mask
        self.register_buffer('mask', mask)

        # Both sub-networks read [masked x, y] and output one value per x coordinate
        total_inputs = num_inputs + num_cond_inputs
        self.scale_net = nn.Sequential(
            nn.Linear(total_inputs, num_hidden), nn.Tanh(),
            nn.Linear(num_hidden, num_hidden), nn.Tanh(),
            nn.Linear(num_hidden, num_inputs))
        self.translate_net = nn.Sequential(
            nn.Linear(total_inputs, num_hidden), nn.ReLU(),
            nn.Linear(num_hidden, num_hidden), nn.ReLU(),
            nn.Linear(num_hidden, num_inputs))

    def forward(self, inputs, cond_inputs, mode='forward'):
        '''
        Apply the coupling transform (mode='forward', x -> u) or its inverse
        (mode='inverse', u -> x).

        Returns (outputs, logdet), where logdet has shape (batch, 1) and holds
        log|det d(outputs)/d(inputs)| for each row.
        Raises ValueError for any other mode.
        '''
        if mode not in ('forward', 'inverse'):
            raise ValueError(f"mode must be 'forward' or 'inverse', got {mode!r}")

        mask = self.mask
        masked_inputs = inputs * mask

        all_inputs = torch.cat([masked_inputs, cond_inputs], dim=1)

        # Zero s and t on the masked coordinates so those pass through unchanged;
        # this is also what makes the inverse computable (the nets see the same input).
        log_s = self.scale_net(all_inputs) * (1 - mask)
        t = self.translate_net(all_inputs) * (1 - mask)

        if mode == 'forward':
            return inputs * torch.exp(log_s) + t, log_s.sum(-1, keepdim=True)

        return (inputs - t) * torch.exp(-log_s), -log_s.sum(-1, keepdim=True)

**Hint 2:** Implement a `CondFlowSequential` class that calls the `CondCouplingLayer` class.

In [None]:
class CondFlowSequential(nn.Sequential):
    """ A sequential container for flows extending the
    FlowSequential class for conditional inputs.

    Each contained module is called as ``module(inputs, cond_inputs, mode)``
    and must return ``(outputs, logdet)`` with ``logdet`` of shape (batch, 1).
    ``mode='forward'`` maps data x to latent u; ``mode='inverse'`` maps u to x.
    """

    def forward(self, inputs, cond_inputs, mode='forward'):
        """ Performs a forward or reverse pass for flow modules.
        Args:
            inputs: (batch, num_inputs) tensor -- x for 'forward', u for 'inverse'
            cond_inputs: (batch, num_cond_inputs) conditioning tensor y
            mode: 'forward' (x -> u) or 'inverse' (u -> x)
        Returns:
            (outputs, logdets): the transformed tensor and the summed
            log|det J| over all layers, shape (batch, 1)
        Raises:
            ValueError: for any other mode
        """
        if mode == 'forward':
            layers = list(self)
        elif mode == 'inverse':
            # The inverse undoes the layers in the opposite order
            layers = list(self)[::-1]
        else:
            raise ValueError(f"mode must be 'forward' or 'inverse', got {mode!r}")

        logdets = torch.zeros(inputs.size(0), 1, device=inputs.device)
        for layer in layers:
            inputs, logdet = layer(inputs, cond_inputs, mode)
            logdets = logdets + logdet
        return inputs, logdets

    def log_probs(self, inputs, cond_inputs):
        """ log p(x | y) under a standard-normal base density, shape (batch, 1). """
        u, log_jacob = self(inputs, cond_inputs)
        log_base = (-0.5 * u.pow(2) - 0.5 * math.log(2 * math.pi)).sum(-1, keepdim=True)
        return log_base + log_jacob

    def sample(self, cond_inputs, num_samples=None):
        """ Draw samples x ~ p(x | y).

        Args:
            cond_inputs: (batch, num_cond_inputs) tensor, one row per sample;
                a single row is broadcast when num_samples is given
            num_samples: defaults to the number of rows of cond_inputs
        Returns:
            (num_samples, num_inputs) tensor on cond_inputs' device
        """
        if num_samples is None:
            num_samples = cond_inputs.size(0)
        elif cond_inputs.size(0) == 1:
            cond_inputs = cond_inputs.expand(num_samples, -1)
        elif cond_inputs.size(0) != num_samples:
            raise ValueError(
                f'cond_inputs has {cond_inputs.size(0)} rows, expected 1 or {num_samples}')

        # The data dimension is the length of a layer's coupling mask
        num_inputs = self[0].mask.shape[-1]
        noise = torch.randn(num_samples, num_inputs, device=cond_inputs.device)

        samples, _ = self(noise, cond_inputs, mode='inverse')
        return samples

**Put the pieces together to define a flow**

In [None]:
# Extending the code that we had before: stack num_blocks conditional coupling
# layers, alternating which coordinate of x is held fixed from layer to layer.
num_cond_inputs = Y.shape[1]

# The mask selects coordinates of x, so it needs num_inputs entries
# (the data dimension), independent of the conditioning dimension.
mask = torch.arange(0, num_inputs) % 2
mask = mask.to(device).float()

modules = []
for _ in range(num_blocks):
    modules.append(CondCouplingLayer(num_inputs, num_cond_inputs, num_hidden, mask))
    mask = 1 - mask

In [None]:
# Define the conditional flow from these modules (applied in list order for x -> u)
cond_flow = CondFlowSequential(*modules)

In [None]:
# Move the conditional flow's parameters onto the selected device
cond_flow = cond_flow.to(device=device)

**Training code**

Training works the same way as before, except that the model is now a conditional generative model, so we'll just give you the training functions.

We'll first define a `torch.utils.data.Dataset` to pass to the `DataLoader`s that we'll create :)

In [None]:
# New dataset class to deal w/ the tuple of X,Y input

from torch.utils.data import Dataset

class CondMoonsDataset(Dataset):
    '''
    Two-moons samples, each shifted by its own uniform [0, 1) center.

    Items are (x, y) pairs, where y is the center and x = moons_point + y.

    Skeleton class taken from:
    https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
    '''
    def __init__(self, nExamples, noise=0.05):
        # Draw the moons first, then the centers (same random-draw order as before)
        moons = torch.Tensor(make_moons(nExamples, noise=noise)[0]).float()
        centers = torch.Tensor(nExamples, moons.shape[1]).uniform_()

        self.X = moons + centers
        self.Y = centers

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

In [None]:
# One freshly generated CondMoonsDataset per split; only training batches are shuffled
cond_train_loader = torch.utils.data.DataLoader(
    dataset=CondMoonsDataset(nTrain),
    batch_size=batch_size,
    shuffle=True,
    **kwargs)

cond_valid_loader = torch.utils.data.DataLoader(
    dataset=CondMoonsDataset(nVal),
    batch_size=batch_size, shuffle=False, drop_last=False,
    **kwargs)

test_loader = torch.utils.data.DataLoader(
    dataset=CondMoonsDataset(nVal),
    batch_size=batch_size, shuffle=False, drop_last=False,
    **kwargs)

In [None]:
def cond_train(model, train_loader, opt):
    '''
    Run one epoch of maximum-likelihood training of a conditional flow.

    Inputs:
    - model: conditional flow with a `log_probs(X, Y)` method
    - train_loader: iterable of (X, Y) batches
    - opt: torch optimizer over the model's parameters

    Returns the mean per-example negative log-likelihood over the epoch
    (the same normalization `cond_validate` uses), or nan for an empty loader.
    '''
    model.train()
    total_loss = 0.
    num_examples = 0

    for X, Y in train_loader:

        X = X.to(device)
        Y = Y.to(device)
        opt.zero_grad()

        loss = -model.log_probs(X, Y).mean()

        loss.backward()
        opt.step()

        # Weight by batch size: the last batch is usually smaller than the others
        total_loss += loss.item() * X.size(0)
        num_examples += X.size(0)

    return total_loss / num_examples if num_examples else float('nan')

def cond_validate(model, loader, prefix='Validation'):
    '''
    Mean per-example negative log-likelihood of a conditional flow on `loader`.

    `loader` yields (X, Y) batches and must expose `.dataset` for the example
    count. `prefix` is accepted but not used.
    '''
    model.eval()
    nll_sum = 0

    with torch.no_grad():
        for X, Y in loader:
            X, Y = X.to(device), Y.to(device)
            nll_sum += -model.log_probs(X, Y).sum().item()  # sum up batch loss

    return nll_sum / len(loader.dataset)

In [None]:
# Optimizer and per-epoch loss bookkeeping for the conditional flow
lr = 1e-4
opt = torch.optim.Adam(params=cond_flow.parameters(), weight_decay=1e-6, lr=lr)

epochs = 81
cond_train_losses, cond_val_losses = np.zeros(epochs), np.zeros(epochs)

for i in range(epochs):
    cond_train_losses[i] = cond_train(cond_flow, cond_train_loader, opt)
    cond_val_losses[i] = cond_validate(cond_flow, cond_valid_loader)

    print(f'Epoch {i}: train loss = {cond_train_losses[i]:.4f}, val loss = {cond_val_losses[i]:.4f}')

    # Every 10 epochs, show the learned density and samples for center y = (0, 0)
    if not i % 10:
        draw_model(cond_flow, cond_input=[0, 0],
                   title=f'Validation loss = {cond_val_losses[i]}')

In [None]:
# Plot the per-epoch train / validation losses recorded in the training loop
plt.plot(cond_train_losses, label='train')
plt.plot(cond_val_losses, label='validation')
plt.xlabel('Epoch')
plt.ylabel('Loss (negative log-likelihood)')
plt.legend()
plt.show()

**Sanity check:** What have we learned about crescent moons in the three cases that we had above?

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 3),
                        gridspec_kw={'wspace': .3})

fig.suptitle('Conditional flow')

# Define the density grid (same window as the data-slice plots above)
x = np.linspace(-1.2, 3.2)
y = np.linspace(-.7, 2.2)

xx, yy = np.meshgrid(x, y)

X_grid = np.vstack([xx.flatten(), yy.flatten()]).T.astype(np.float32)
X_grid = torch.tensor(X_grid).to(device)

# Loop over the centers
for yi, ax, cmap in zip([0, 0.5, 1], axs, ['Blues', 'Oranges', 'Greens']):

    # Condition every grid point on the same center (yi, yi, ...); the width
    # comes from the conditioning dimension, not the data dimension
    Y_grid = torch.full((X_grid.shape[0], num_cond_inputs), float(yi),
                        dtype=torch.float32, device=device)

    with torch.no_grad():
        log_probs = cond_flow.log_probs(X_grid, Y_grid).cpu().numpy()

    # Put the flat grid values back onto the meshgrid shape
    ax.pcolormesh(xx, yy, np.exp(log_probs.reshape(xx.shape)),
                  shading='auto', cmap=cmap)

    ax.set_xlabel('$X_0$', fontsize=12)
    ax.set_ylabel('$X_1$', fontsize=12)
    ax.set_title(f'center $y = ({yi}, {yi})$')

    # Mark the conditioning center
    ax.scatter([yi], [yi], 300, marker='x', color='k')
plt.show()


**Resources:**
- The code in this tutorial is adapted from the [pytorch-flows](https://github.com/ikostrikov/pytorch-flows) repo.
- The [nflows](https://github.com/bayesiains/nflows.git) package is also very nice: it includes the Real-NVP model as well as the RQ-NSF that we talked about in the lecture.