# GFN for music generation
code reference: https://github.com/Tikquuss/GflowNets_Tutorial

The reward function has the form
$$R(x\mid y)=e^{-a\,d(x,y)}$$
where $x$ is a source image (here an $M\times N$ grid), $y$ is the grid constructed by the GFN, and $d(x,y)$ is the distance between them (the code uses the mean squared error, i.e. the squared Euclidean distance divided by $MN$). We start with a blank, black image (all zeros). At each step, the GFN either chooses a pixel and sets it to 1, or takes the STOP action.

At each state, the available actions are the remaining (not yet colored) coordinate pairs in $[1,M]\times[1,N]$, plus the STOP action.

In [None]:
import torch
import tqdm
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl

In [None]:
# Load the training data at a fixed song length.
from preprocessing import loadData

songLen = 500  # song length handed to the project's loader

# NOTE(review): the structure of `data` is defined by preprocessing.loadData.
data = loadData(songLen)

In [None]:
import config

# Train on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Grid dimensions: M comes from the song length, N from config.MIDI_NOTE_RANGE.
M = songLen
N = config.MIDI_NOTE_RANGE

In [None]:
def make_mlp(l, act=torch.nn.LeakyReLU(), tail=None):
    """Build a fully-connected network as a ``torch.nn.Sequential``.

    Args:
        l: Layer widths ``[in, hidden..., out]``. Each consecutive pair
            ``(i, o)`` becomes a ``torch.nn.Linear(i, o)`` layer.
        act: Activation module inserted after every Linear layer except the
            last one. The same instance is reused between layers, which is
            fine for parameter-free activations such as the default LeakyReLU.
        tail: Optional iterable of modules appended after the final Linear
            layer. Defaults to no extra modules.

    Returns:
        A ``torch.nn.Sequential`` containing the layers described above.
    """
    # None instead of a mutable [] default, so no list is shared between calls.
    tail = [] if tail is None else tail

    layers = []
    n_linear = len(l) - 1
    for n, (i, o) in enumerate(zip(l, l[1:])):
        layers.append(torch.nn.Linear(i, o))
        # No activation after the output layer.
        if n < n_linear - 1:
            layers.append(act)
    layers.extend(tail)
    return torch.nn.Sequential(*layers)

In [None]:
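# TODO: data import here. The training loop below reads a target `x` that is not defined anywhere else in this notebook; define it in this cell (presumably from `data`).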

In [None]:
n_hid = 256
n_layers = 2
ndim = M * N
input_dim = ndim  # the flattened M x N grid is fed directly to the MLP
# Logits layout: ndim + 1 for P_F (one per cell plus STOP), then ndim for P_B.
output_dim = 2 * ndim + 1

# When True, the network predicts logZ itself through one extra output unit,
# instead of learning a single global scalar. For a GFN it is better to learn
# a Z for each target than the global Z used here.
independent_Z = False
if independent_Z:
    output_dim += 1

model_TB = make_mlp([input_dim] + [n_hid] * n_layers + [output_dim]).to(device)

if independent_Z:
    # Placeholder; the training loop overwrites it with the network's logZ output.
    logZ_TB = 0
    optimizer = torch.optim.Adam([{'params': model_TB.parameters(), 'lr': 0.001}])
else:
    # Learnable global log Z (log of the initial-state flow), initialised to 0
    # (Z = 1). It is created directly on `device` as a leaf that already
    # requires grad, so the optimizer is never handed a non-tracking tensor.
    logZ_TB = torch.zeros(1, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([
        {'params': model_TB.parameters(), 'lr': 0.001},
        {'params': [logZ_TB], 'lr': 0.1},
    ])

model_TB

In [None]:
# GFN hyperparams
n_train_steps = 10000
batch_size = 1
uniform_PB = False
minus_inf = -1e8

a = 100

In [None]:
# Per-step training history, appended to by the training loop and pickled afterwards.
losses_TB, rewards_TB, logZ_TB_list = [], [], []
# Every terminal state reached during training.
all_visited_TB = []

In [None]:
# Target grid the reward compares against, flattened to (1, ndim) so it
# broadcasts over the batch and converted once to the states' dtype/device.
# NOTE(review): `x` is not defined anywhere in this notebook as shown (the
# "data import" cell is empty). It must be created beforehand, presumably from
# `data`, with M * N elements.
target = torch.as_tensor(x).to(device=device, dtype=torch.get_default_dtype()).reshape(1, ndim)

for it in tqdm.trange(n_train_steps):
    # Trajectory-balance residual, one entry per trajectory:
    #   log Z + sum_t log P_F(s_{t+1}|s_t) - sum_t log P_B(s_t|s_{t+1}) - log R
    loss_TB = torch.zeros((batch_size,)).to(device)
    # The global logZ is added up front. With independent_Z, it is added at the
    # first step from the network's own prediction instead.
    loss_TB += 0 if independent_Z else logZ_TB
    # finished trajectories
    dones = torch.full((batch_size,), False, dtype=torch.bool).to(device)
    # s_0: blank grid (all zeros)
    states = torch.zeros(size=(batch_size, ndim)).to(device)  # (batch_size, ndim)
    # Actions chosen at the previous step, shape (current_batch_size, 1);
    # None before the first step.
    actions = None

    # Effectively no cap: each of the ndim cells can be colored at most once,
    # after which only STOP remains unmasked, so trajectories end on their own.
    max_steps = 1e8
    sample_temperature = 1
    i = 0
    while torch.any(~dones) and i <= max_steps:
        ### Forward pass ###
        current_batch_size = (~dones).sum()
        non_terminal_states = states[~dones]  # (current_batch_size, ndim)
        logits = model_TB(non_terminal_states)  # (current_batch_size, output_dim)

        ### Backward policy ###
        PB_logits = logits[..., ndim + 1:2 * ndim + 1]  # (current_batch_size, ndim)
        PB_logits = PB_logits * (0 if uniform_PB else 1)  # (current_batch_size, ndim)
        # Cells that are still black (0) cannot be un-colored, so they are
        # excluded from the backward policy's action space.
        PB_mask = (non_terminal_states == 0.).float()  # (current_batch_size, ndim)
        logPB = (PB_logits + minus_inf * PB_mask).log_softmax(1)  # (current_batch_size, ndim)
        if actions is not None:
            # log P_B(s_t | s_{t+1}) for the cell colored at the previous step.
            # STOP actions are dropped: those trajectories are already done.
            loss_TB[~dones] -= logPB.gather(1, actions[actions != ndim].unsqueeze(1)).squeeze(1)
        elif independent_Z:
            # First step: the network's extra output unit is each trajectory's logZ.
            logZ_TB = logits[..., -1]
            loss_TB += logZ_TB
            logZ_TB = logZ_TB.mean()  # scalar, kept only for logging

        ### Forward policy ###
        PF_logits = logits[..., :ndim + 1]  # (current_batch_size, ndim + 1)
        # Cells that are already white (1) cannot be colored again; STOP is
        # always allowed.
        edge_mask = (non_terminal_states == 1.).float()  # (current_batch_size, ndim)
        stop_action_mask = torch.zeros((current_batch_size, 1), device=device)  # (current_batch_size, 1)
        PF_mask = torch.cat([edge_mask, stop_action_mask], 1)  # (current_batch_size, ndim + 1)
        logPF = (PF_logits + minus_inf * PF_mask).log_softmax(1)  # (current_batch_size, ndim + 1)
        sample_ins_probs = (logPF / sample_temperature).softmax(1)  # (current_batch_size, ndim + 1)
        actions = sample_ins_probs.multinomial(1)  # (current_batch_size, 1)
        loss_TB[~dones] += logPF.gather(1, actions).squeeze(1)

        ### Select terminal states ###
        terminates = (actions == ndim).squeeze(1)
        for state in non_terminal_states[terminates]:
            all_visited_TB.append(state)

        # Update dones
        dones[~dones] |= terminates

        # Color the chosen cell in every trajectory that did not stop.
        with torch.no_grad():
            # squeeze(1) rather than squeeze() keeps a 1-D index tensor even
            # when a single trajectory remains.
            non_terminates = actions[~terminates].squeeze(1)
            tmp = states[~dones]
            tmp[torch.arange((~dones).sum()), non_terminates] = 1.
            states[~dones] = tmp

        i += 1

    # Per-trajectory mean squared error to the target, shape (batch_size,).
    # For batch_size == 1 this equals mse_loss(states.view(M, N), x.view(M, N)),
    # and unlike .view(M, N) it also works for batch_size > 1.
    dist = ((states - target) ** 2).mean(dim=1)
    # Use log R = -a * dist directly. R = exp(-a * dist) underflows float32
    # (denormal or 0) for large a * dist, and R.log() would then lose precision
    # or become -inf.
    logR = -a * dist
    R = logR.exp()
    loss_TB -= logR
    loss = (loss_TB ** 2).sum() / batch_size

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses_TB.append(loss.item())
    rewards_TB.append(R.mean().cpu())
    logZ_TB_list.append(logZ_TB.item())

    if it % 100 == 0:
        print('\nloss =', np.array(losses_TB[-1:]).mean(), 'logZ =', logZ_TB.item(), "R =", np.array(rewards_TB[-1:]).mean())

In [None]:
import os

# Persist the trained model and the training history as pickles, one file per
# object. The output directory is created if missing, so the first run does
# not fail with FileNotFoundError.
# NOTE(review): pickling the whole nn.Module ties the file to this code's class
# layout; torch.save(model_TB.state_dict(), ...) would be more portable.
save_dir = "./trained_gfn_model"
os.makedirs(save_dir, exist_ok=True)
for name, obj in (
    ("model_TB", model_TB),
    ("losses_TB", losses_TB),
    ("rewards_TB", rewards_TB),
    ("logZ_TB_list", logZ_TB_list),
    ("all_visited_TB", all_visited_TB),
):
    with open(os.path.join(save_dir, f"{name}.pkl"), "wb") as f:
        pkl.dump(obj, f)