In [1]:
import torch
# Prefer the first CUDA GPU when one is visible to torch; otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

In [2]:
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from data_modules import data_modules as dm

In [3]:
class ActorNetwork(nn.Module):
    """Policy network: maps the current state to a parameter vector whose entries lie in [0, 1].

    The state is the triple (target spectrogram, spectrogram rendered from the previous
    prediction, previous parameter vector). Both spectrograms are passed through the same
    `feature_extractor_network` instance. The two feature vectors are concatenated with the
    raw previous parameter vector and fed through an MLP head.

    Args:
        feature_extractor_network: module applied to each spectrogram; the head's input
            size assumes it returns `feature_dims` values per spectrogram.
        feature_dims: size of each extracted feature vector.
        parameter_vector_dims: size of the parameter vector (both an input and the output).
    """

    def __init__(self, feature_extractor_network: nn.Module, feature_dims : int, parameter_vector_dims : int):
        super().__init__()
        self.feature_extractor_network = feature_extractor_network

        # Head input = target features + prior-spectrogram features + raw prior parameter vector.
        state_dims = 2 * feature_dims + parameter_vector_dims
        hidden_dims = 2 * parameter_vector_dims
        self.actor_head = nn.Sequential(
            nn.Linear(state_dims, hidden_dims),
            nn.ELU(),
            nn.Linear(hidden_dims, parameter_vector_dims),
            nn.ELU(),
            nn.Linear(parameter_vector_dims, parameter_vector_dims),
            nn.Sigmoid(),  # squashes every output parameter into [0, 1]
        )

    def forward(self, target_spectrogram, prior_spectrogram, prior_parameter_vector):
        """Return the predicted parameter vector (the action) for the given state."""
        encode = self.feature_extractor_network
        target_feats = encode(target_spectrogram)
        prior_feats = encode(prior_spectrogram)
        return self.actor_head(torch.hstack((target_feats, prior_feats, prior_parameter_vector)))

In [4]:
class CriticNetwork(nn.Module):
    """Value network: maps the current state to a single scalar value estimate.

    The state is built exactly as in `ActorNetwork`: both spectrograms go through the
    same `feature_extractor_network`, and the two feature vectors are concatenated with the
    raw prior parameter vector. The training loop uses the output as the state value V(s)
    in the TD target r + gamma * V(s').

    The output has no final activation, so it is unbounded. This matches the reward,
    which is a negated MSE and therefore not confined to [-1, 1].

    Args:
        feature_extractor_network: module applied to each spectrogram; the head's input
            size assumes it returns `feature_dims` values per spectrogram.
        feature_dims: size of each extracted feature vector.
        parameter_vector_dims: size of the prior parameter vector in the state.
    """

    def __init__(self, feature_extractor_network: nn.Module, feature_dims : int, parameter_vector_dims : int):
        super().__init__()
        self.feature_extractor_network = feature_extractor_network

        state_dims = 2 * feature_dims + parameter_vector_dims
        self.critic_head = nn.Sequential(
            nn.Linear(state_dims, 512),
            nn.ELU(),
            nn.Linear(512, 256),
            nn.ELU(),
            nn.Linear(256, 1),
        )

    def forward(self, target_spectrogram, prior_spectrogram, prior_parameter_vector):
        """Return the value estimate for the given state (last dimension of size 1)."""
        encode = self.feature_extractor_network
        target_feats = encode(target_spectrogram)
        prior_feats = encode(prior_spectrogram)
        return self.critic_head(torch.hstack((target_feats, prior_feats, prior_parameter_vector)))

In [5]:
# `os` was used below without being imported, so this cell failed on a fresh kernel.
import os

# Dummy-spectrogram settings (only used by the commented-out dummy experiment):
# spectrogram_shape = (1, 4, 8)
# parameter_vector_dims = 16 #np.prod(spectrogram_shape)

# Size of the feature vector produced by the spectrogram feature extractor.
feature_vector_dims = 1024

# Real spectrograms from the preset dataset. `return_params=True` requests the
# ground-truth parameters alongside each example.
preset_data = dm.TargetSoundDataModule(
    data_dir=os.path.join('data', 'preset_data'),
    split_file='split_dict.json',
    num_workers=3,
    batch_size = 1,
    shuffle=True,
    return_params = True
)
preset_data.setup()

# One training example; later cells use it as the fixed target for the whole demo.
t0 = next(iter(preset_data.train_dataloader()))
print(t0.keys()) # Check that params are also returned
assert 'params' in t0, "TargetSoundDataModule did not return 'params'; check return_params=True"

# Renders (audio, spectrogram) from a parameter vector via `generateAudio` in later cells.
audiohandler = dm.AudioHandler()

dict_keys(['audio', 'params', 'spectrogram'])


In [6]:
# IGNORE - For dummy spectrogram

# [parameter vector dims, spectrogram flattened num dims]
# parameter_vector_to_spectrogram_matrix = ortho_group.rvs(dim=np.prod(spectrogram_shape)).astype(np.float32)
# parameter_vector_to_spectrogram_matrix = torch.from_numpy(parameter_vector_to_spectrogram_matrix).to(device).requires_grad_(False) 
# parameter_vector_to_spectrogram_matrix = torch.randn(parameter_vector_dims, np.prod(spectrogram_shape), requires_grad=False, device=device)
# parameter_vector_to_spectrogram_matrix.shape

In [7]:
# IGNORE - For dummy spectrogram
# def generate_spectrogram_from_parameter_vector(parameter_vector):
#     # [N, spectrogram flattened num dims]
#     spectrogram_flattened = parameter_vector @ parameter_vector_to_spectrogram_matrix
#     spectrogram = torch.reshape(spectrogram_flattened, (-1,) + spectrogram_shape)
#     return spectrogram

## Demo training

In [8]:
from torch.optim import Adam
from tqdm.auto import tqdm

In [9]:
# IGNORE - For dummy spectrogram
# ground_truth_params_vector = torch.randn(1, parameter_vector_dims, device=device)
# ground_truth_spectrogram = generate_spectrogram_from_parameter_vector(ground_truth_params_vector)
# ground_truth_params_vector.shape, ground_truth_spectrogram.shape

# The first (and only) example of the batch is the fixed target for the demo.
ground_truth_params_vector = t0['params']
ground_truth_spectrogram = t0['spectrogram']
ground_truth_audio = t0['audio']
# Total number of parameters as a plain Python int. np.prod(shape) would give a NumPy
# scalar, and this value is used below as a layer size and a tensor size.
parameter_vector_dims = ground_truth_params_vector.numel()
ground_truth_params_vector.shape, ground_truth_spectrogram.shape, parameter_vector_dims

(torch.Size([1, 155]), torch.Size([1, 257, 345]), 155)

In [10]:
# These should be close if parameter_vector_to_spectrogram_matrix is orthonormal
# ground_truth_params_vector, (ground_truth_spectrogram.reshape(1, -1) @ parameter_vector_to_spectrogram_matrix.T)

In [11]:
# Used by `compute_reward` for the real-spectrogram reward (not only the dummy experiment).
def compute_mse_between_spectrograms(spec1, spec2):
    """Mean squared error between two spectrograms, averaged over axes 0 and 1.

    The training loop passes squeezed 2-D arrays, so the result is a single scalar.
    For inputs with more than two axes, only the first two axes are reduced.

    NOTE(review): if either spectrogram contains inf/NaN (e.g. a log of zero), the
    result is NaN. This may be the source of the NaN reward seen in the training
    output — confirm.
    """
    squared_error = (spec1 - spec2) ** 2
    return squared_error.mean(axis=(0, 1))

In [12]:
# IGNORE - For dummy spectrogram
# def compute_reward(spec1, spec2):
#     # output: [N,1]
    
#     # first, we compute mse loss and use that to construct a reward
#     mse_loss = compute_mse_between_spectrograms(spec1, spec2)
    
#     # with the transformation below, higher mse loss -> lower reward, and lower mse loss -> higher reward
#     reward = -mse_loss #1/mse_loss #-torch.log(mse_loss) #1/mse_loss
    
#     # finally, to keep reward within [-1,1] range, we use tanh
#     # reward = torch.tanh(reward)
    
#     return reward

# Actual spectrogram reward
def compute_reward(spec1, spec2):
    """Reward for matching the target spectrogram: the negated MSE between `spec1` and `spec2`.

    A smaller spectrogram error gives a higher (less negative) reward. No squashing is
    applied (a tanh was considered but is disabled), so the reward is unbounded below.
    """
    return -compute_mse_between_spectrograms(spec1, spec2)

In [13]:
def compute_returns(rewards, gamma=0.99):
    """Discounted return G_t = r_t + gamma * G_{t+1} for every step of an episode.

    Adapted from
    https://github.com/yc930401/Actor-Critic-pytorch/blob/5d359bc93839357255b591c0a87fc42542854eb8/Actor-Critic.py#L50

    Args:
        rewards: per-step rewards in chronological order (any indexable sequence).
        gamma: discount factor.

    Returns:
        A list with one entry per reward; entry t is the discounted return from step t.
        An empty `rewards` yields an empty list (previously this raised IndexError).
    """
    if len(rewards) == 0:
        return []

    returns = [None] * len(rewards)
    # The last step has nothing after it: G_T = r_T.
    running_return = rewards[-1]
    returns[-1] = running_return
    # Walk backwards, filling slots in place (avoids O(n^2) list.insert(0, ...)).
    for step in range(len(rewards) - 2, -1, -1):
        running_return = rewards[step] + gamma * running_return
        returns[step] = running_return
    return returns

In [14]:
class DummySpectrogramFeatureExtractorNetwork(nn.Module):
    """
    Placeholder spectrogram encoder: a two-layer MLP over the fully flattened spectrogram.
    A pre-trained encoder (e.g. the VAE) is intended to replace it.

    Args:
        spectrogram_shape: shape of the input as passed to `forward`, including any leading
            batch axis; its product sets the first layer's input size.
        feature_dims: size of the output feature vector.
    """

    def __init__(self, spectrogram_shape : tuple[int], feature_dims : int):
        super().__init__()
        flat_dims = np.prod(spectrogram_shape)
        layers = [
            nn.Linear(flat_dims, feature_dims),
            nn.ReLU(),
            nn.Linear(feature_dims, feature_dims),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, spectrogram):
        # Every axis, the batch axis included, is flattened into one 1-D vector, so the
        # output is a 1-D feature vector of length `feature_dims` (not [N, feature_dims]).
        return self.model(spectrogram.reshape(-1))

In [15]:
# def get_next_state(predicted_parameter_vector):
#     next_spectrogram = generate_spectrogram_from_parameter_vector(predicted_parameter_vector)
#     return [ground_truth_spectrogram, next_spectrogram, predicted_parameter_vector]

def get_next_state(ground_truth_spectrogram, predicted_parameter_vector, audiohandler, device):
    """Render `predicted_parameter_vector` and return the resulting state, moved to `device`.

    Returns:
        list: [ground_truth_spectrogram, spectrogram rendered from the prediction,
        predicted_parameter_vector]. The rendered audio is discarded.
    """
    _rendered_audio, next_spectrogram = audiohandler.generateAudio(predicted_parameter_vector)
    return [t.to(device) for t in (ground_truth_spectrogram, next_spectrogram, predicted_parameter_vector)]

In [16]:
# Use dummy feature extractor

In [17]:
# --- Optimisation ---
ACTOR_LR = 1e-9
CRITIC_LR = 1e-8
GRADIENT_CLIPPING = 0.01  # max gradient norm; set to None to disable clipping
step_optimizer_every_n_episode_steps = 10

# --- Episode schedule ---
num_episodes = 1_000
num_steps_per_episode = 100
gamma = 0.99  # discount factor for the TD target and returns

In [18]:
# Stand-in encoder sized to the ground-truth spectrogram. nn.Module.to moves the module
# in place and returns it, so chaining keeps the same object.
feature_extractor_network = DummySpectrogramFeatureExtractorNetwork(
    spectrogram_shape=ground_truth_spectrogram.shape,
    feature_dims=feature_vector_dims,
).to(device)

In [19]:
# Both networks receive the *same* feature-extractor instance, so its weights are shared
# and receive gradients from both the actor and the critic losses.
network_kwargs = dict(
    feature_extractor_network=feature_extractor_network,
    feature_dims=feature_vector_dims,
    parameter_vector_dims=parameter_vector_dims,
)
actor_model = ActorNetwork(**network_kwargs).to(device)
critic_model = CriticNetwork(**network_kwargs).to(device)

In [20]:
# The actor and critic share `feature_extractor_network`. If those shared parameters were
# handed to both optimizers, each update would step them twice (once per Adam instance,
# each with its own moment estimates). The critic optimizer therefore only receives the
# critic parameters that the actor optimizer does not already own.
actor_optimizer = Adam(actor_model.parameters(), lr=ACTOR_LR)
_actor_param_ids = {id(p) for p in actor_model.parameters()}
critic_optimizer = Adam(
    [p for p in critic_model.parameters() if id(p) not in _actor_param_ids],
    lr=CRITIC_LR,
)

In [21]:
import warnings

# The original cell used hard-coded `break`s so that only one step of one episode ran
# while the NaN-loss TODO was investigated. That stays the default; set
# DEBUG_SINGLE_STEP = False for a full training run.
DEBUG_SINGLE_STEP = True
VERBOSE = DEBUG_SINGLE_STEP  # dump per-step tensors only while debugging
CRITIC_LOSS_WEIGHT = 0.5     # weight of the critic (TD) loss relative to the actor loss
LOG_EPS = 1e-8               # keeps log() finite if a sigmoid output underflows to 0


def _apply_accumulated_gradients():
    """Optionally clip, apply, then clear the gradients accumulated since the last update."""
    if GRADIENT_CLIPPING is not None:
        torch.nn.utils.clip_grad_norm_(actor_model.parameters(), GRADIENT_CLIPPING)
        torch.nn.utils.clip_grad_norm_(critic_model.parameters(), GRADIENT_CLIPPING)
    actor_optimizer.step()
    critic_optimizer.step()
    actor_optimizer.zero_grad()
    critic_optimizer.zero_grad()


# Loop-invariant: the target parameters only need moving to the device once.
ground_truth_params_on_device = ground_truth_params_vector.to(device)

episode_pbar = tqdm(range(num_episodes))
for episode_idx in episode_pbar:
    # Initial state: a random parameter vector and the spectrogram rendered from it.
    prior_predicted_parameter_vector = torch.rand(parameter_vector_dims, device=device)
    audio, prior_predicted_spectrogram = audiohandler.generateAudio(prior_predicted_parameter_vector.cpu().detach())
    state = (ground_truth_spectrogram.to(device), prior_predicted_spectrogram.to(device), prior_predicted_parameter_vector)

    all_rewards = []
    # Gradients accumulate across steps and are applied every
    # `step_optimizer_every_n_episode_steps` steps. They must only be cleared after an
    # update; clearing them every step silently discarded all but the last one.
    actor_optimizer.zero_grad()
    critic_optimizer.zero_grad()
    has_pending_grads = False

    episode_step_pbar = tqdm(range(num_steps_per_episode), leave=False)
    for step_idx in episode_step_pbar:
        # Action (next parameter vector) and value estimate V(s) for the current state.
        predicted_parameter_vector = actor_model(*state)
        value_of_current_state = critic_model(*state)

        # Perform the action: render the predicted parameters and score them against the target.
        next_state = get_next_state(ground_truth_spectrogram.cpu(), predicted_parameter_vector.cpu().detach(), audiohandler, device=device)
        reward = float(compute_reward(
            ground_truth_spectrogram.cpu().squeeze().detach().numpy(),
            next_state[1].cpu().squeeze().detach().numpy(),
        ))

        # Semi-gradient TD(0): the bootstrap target r + gamma * V(s') is treated as a
        # constant, so V(s') is evaluated without building a graph.
        with torch.no_grad():
            value_next_state = critic_model(*next_state)
        td_target = reward + gamma * value_next_state
        delta = td_target - value_of_current_state

        # The TD error acts as the advantage for the actor. It is detached so the actor loss
        # does not push gradients into the critic.
        # NOTE(review): -log(action) is not the log-probability of a distribution over
        # continuous actions; a true policy gradient needs a stochastic policy (e.g. a
        # Gaussian or Beta head). TODO revisit.
        advantage = delta.detach()
        log_action = torch.log(predicted_parameter_vector.clamp_min(LOG_EPS))
        actor_loss = -(advantage * log_action).mean()
        critic_loss = (delta ** 2).mean()
        loss = actor_loss + CRITIC_LOSS_WEIGHT * critic_loss

        if VERBOSE:
            print(f"predicted_parameter_vector {predicted_parameter_vector}")
            print(f"value_of_current_state {value_of_current_state}")
            print(f"value_next_state {value_next_state}")
            print(f"reward {reward}")
            print(f"actor loss {actor_loss}")
            print(f"critic loss {critic_loss}")
            print(loss)

        if bool(torch.isfinite(loss)):
            loss.backward()
            has_pending_grads = True
            if (step_idx + 1) % step_optimizer_every_n_episode_steps == 0:
                _apply_accumulated_gradients()
                has_pending_grads = False
        else:
            # Back-propagating a NaN/inf loss would write NaN into the gradients and, after
            # the optimizer step, permanently into the weights. Skip it and surface the problem.
            warnings.warn(
                f"Non-finite loss at episode {episode_idx}, step {step_idx} "
                f"(reward={reward}); skipping this update."
            )

        # Continue from the next state; its tensors must not carry a graph into the next step.
        state = (next_state[0], next_state[1].detach(), next_state[2].detach())

        all_rewards.append(reward)
        with torch.no_grad():
            params_mse = ((predicted_parameter_vector - ground_truth_params_on_device) ** 2).mean().item()
        episode_step_pbar.set_postfix({"loss" : loss.item(), "reward" : reward, "params_mse" : params_mse})

        if DEBUG_SINGLE_STEP:
            break

    # Apply gradients left over from a partial accumulation window (or the single debug step).
    if has_pending_grads:
        _apply_accumulated_gradients()

    episode_return = compute_returns(all_rewards, gamma)[0] if all_rewards else 0.0
    episode_pbar.set_postfix({"total reward" : sum(all_rewards), "return" : episode_return})

    if DEBUG_SINGLE_STEP:
        break

  0%|          | 0/100 [00:01<?, ?it/s, loss=nan, reward=nan, params_mse=0.257]
  0%|          | 0/1000 [00:02<?, ?it/s]

predicted_parameter_vector tensor([0.9503, 0.0879, 0.4526, 0.3286, 0.0894, 0.8229, 0.0935, 0.9525, 0.9960,
        0.6931, 0.3791, 0.3311, 0.4486, 0.9582, 0.6784, 0.8286, 0.9353, 0.9366,
        0.2177, 0.9026, 0.8711, 0.0658, 0.4316, 0.7656, 0.7646, 0.0814, 0.7100,
        0.6752, 0.9239, 0.8133, 0.1706, 0.7685, 0.4450, 0.2497, 0.9927, 0.0041,
        0.1987, 0.5134, 0.7998, 0.7241, 0.0659, 0.3664, 0.8363, 0.6985, 0.4658,
        0.1252, 0.3599, 0.9907, 0.8322, 0.4194, 0.6232, 0.0367, 0.2499, 0.0732,
        0.5058, 0.0015, 0.3950, 0.4055, 0.2156, 0.5820, 0.1767, 0.2538, 0.1763,
        0.4342, 0.8024, 0.0425, 0.6188, 0.0759, 0.1770, 0.1257, 0.9688, 0.1353,
        0.6636, 0.6895, 0.1389, 0.2710, 0.3197, 0.3680, 0.0430, 0.7992, 0.2915,
        0.6067, 0.0291, 0.2426, 0.2113, 0.0395, 0.9691, 0.2020, 0.9298, 0.4228,
        0.5649, 0.9000, 0.1025, 0.7228, 0.7514, 0.8766, 0.2069, 0.9645, 0.7439,
        0.6479, 0.7932, 0.0511, 0.8922, 0.8471, 0.1157, 0.2993, 0.8362, 0.4723,
        0.998


