# Lesson 4: Recurrent layers

*Teachers:* Fares Schulz, Lina Campanella

In this lesson we will cover:
1. Building a hybrid CNN-LSTM architecture for Bach chorale generation
2. Temperature-controlled sampling techniques for creative AI applications

## Generating Bach Chorales

Disclaimer: This example is largely taken from Aurélien Géron's book *Hands-On Machine Learning*. You can find the original TensorFlow/Keras implementation [here](https://github.com/ageron/handson-ml2/blob/master/15_processing_sequences_using_rnns_and_cnns.ipynb); this lesson ports it to PyTorch.

For this exercise we will use the JSB Chorales dataset, which contains 382 chorales by Johann Sebastian Bach. Each chorale is 100 to 640 time steps long. Each time step contains 4 integers, and each integer corresponds to a MIDI note; the value 0 means that no note is played. We will train a model with both convolutional and recurrent layers that predicts the next time step (four notes) from a sequence of time steps in a chorale. We will then use this model to generate Bach-like music step by step. We give the model the start of a chorale and ask it to predict the next time step. We append the prediction to the input sequence, ask the model for the next step, and so on.
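
The generation procedure described above is plain autoregressive decoding. The sketch below shows the loop on its own. It assumes a hypothetical `predict_next` callable that maps the sequence so far to the next time step. The toy predictor `repeat_last` is also hypothetical: it stands in for the trained model and simply repeats the last chord. The seed chords are arbitrary values chosen for illustration.

```python
from typing import Callable, List

Chord = List[int]  # one time step: four MIDI note numbers, 0 = silence

def generate(seed: List[Chord], n_steps: int,
             predict_next: Callable[[List[Chord]], Chord]) -> List[Chord]:
    """Autoregressive loop: predict one step, append it, feed the longer sequence back in."""
    sequence = [list(chord) for chord in seed]  # copy so the caller's seed is untouched
    for _ in range(n_steps):
        next_chord = predict_next(sequence)
        sequence.append(list(next_chord))
    return sequence

def repeat_last(sequence: List[Chord]) -> Chord:
    """Hypothetical stand-in for a trained model: always repeat the last chord."""
    return sequence[-1]

seed = [[74, 70, 65, 58], [74, 70, 65, 58]]
result = generate(seed, 3, repeat_last)
print(len(result))  # 5
```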

In [None]:
import urllib.request
import tarfile
from pathlib import Path
import pandas as pd

# Download the dataset using urllib and extract with tarfile
download_link = "https://github.com/iCorv/jsb-chorales-dataset/raw/main/jsb_chorales.tar"
data_dir = Path('resources/_data/jsb_chorales')
tar_path = data_dir / 'jsb_chorales.tar'

# Create directory if it doesn't exist
data_dir.mkdir(parents=True, exist_ok=True)

# Download the file if it doesn't already exist
if not tar_path.exists():
    print(f"Downloading dataset from {download_link}")
    urllib.request.urlretrieve(download_link, tar_path)
    print(f"Downloaded to {tar_path}")
else:
    print(f"Dataset already exists at {tar_path}")

# Extract the tar file unless the train/ split (expected by the loading code below) is already there
if tar_path.exists() and not (data_dir / 'train').exists():
    print(f"Extracting {tar_path}")
    with tarfile.open(tar_path, 'r') as tar:
        tar.extractall(path=data_dir)
    print(f"Extracted to {data_dir}")

filepath = str(tar_path)
print(f"Dataset available at: {filepath}")

In [None]:
jsb_chorales_dir = Path(filepath).parent
train_files = sorted(jsb_chorales_dir.glob("train/chorale_*.csv"))
valid_files = sorted(jsb_chorales_dir.glob("valid/chorale_*.csv"))
test_files = sorted(jsb_chorales_dir.glob("test/chorale_*.csv"))

In [None]:
def load_chorales(filepaths):
    return [pd.read_csv(filepath).values.tolist() for filepath in filepaths]

train_chorales = load_chorales(train_files)
valid_chorales = load_chorales(valid_files)
test_chorales = load_chorales(test_files)

In [None]:
notes = set()
for chorales in (train_chorales, valid_chorales, test_chorales):
    for chorale in chorales:
        for chord in chorale:
            notes |= set(chord)

n_notes = len(notes)
min_note = min(notes - {0})
max_note = max(notes)

assert min_note == 36
assert max_note == 81

Notes range from 36 (C2) to 81 (A5), using the convention where MIDI note 60 is middle C (C4), plus 0 for silence. This is what the data looks like:

In [None]:
# each chorale is a list of chords: one row of 4 MIDI notes per time step
train_chorales[0][:10]
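
Octave numbers depend on the naming convention. The helper below is a small sketch that assumes the common convention in which MIDI note 60 is middle C (C4). Some music software labels note 60 as C3 instead, which shifts every octave number down by one. The function name `midi_to_name` is illustrative and not part of the notebook.

```python
NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]

def midi_to_name(midi_number: int) -> str:
    """Name a MIDI note using the convention where 60 = C4 (middle C)."""
    if midi_number == 0:
        return "rest"  # in this dataset 0 encodes silence, not MIDI note 0
    octave = midi_number // 12 - 1
    return f"{NOTE_NAMES[midi_number % 12]}{octave}"

print(midi_to_name(36), midi_to_name(60), midi_to_name(81))  # C2 C4 A5
```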

We use this simple NumPy synthesizer to play the first training chorale, once with `baroque_tuning=True` and once with `baroque_tuning=False`:

In [None]:
from resources._code.synthesizer import SimpleSynth

baroque_synth = SimpleSynth(tempo=160, amplitude=0.1, sample_rate=44100, baroque_tuning=True)
devine_synth = SimpleSynth(tempo=160, amplitude=0.1, sample_rate=44100, baroque_tuning=False)

for idx in range(1):
    baroque_synth.play_chorale(train_chorales[idx])
    devine_synth.play_chorale(train_chorales[idx])

To generate new chorales, we want to train a model that can predict the next chord given all the previous chords. If we naively predict the next chord in one shot, with all 4 notes at once, we risk getting notes that don't go well together. It's much better and simpler to predict one note at a time. So we will preprocess every chorale and turn each chord into an arpeggio, that is, a sequence of notes rather than notes played simultaneously. Each chorale then becomes a long sequence of notes rather than chords, and we can train a model that predicts the next note given all the previous notes. We will use a sequence-to-sequence approach: we feed a window to the neural net, and it tries to predict the same window shifted one time step (one note) into the future.
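
In plain Python, the two steps look roughly like this: first flatten the chords into an arpeggio, then pair the sequence with itself shifted by one note. The chord values are made up for illustration, and the helper names are hypothetical. The notebook performs the same operations on tensors with `reshape(-1)` and slicing.

```python
def to_arpeggio(chords):
    """Flatten a list of 4-note chords into one note sequence, chord by chord."""
    return [note for chord in chords for note in chord]

def make_seq2seq_pair(notes):
    """Input is every note but the last; target is the same sequence shifted one step ahead."""
    return notes[:-1], notes[1:]

chords = [[74, 70, 65, 58], [74, 70, 65, 58], [75, 70, 67, 60]]
arpeggio = to_arpeggio(chords)
X, Y = make_seq2seq_pair(arpeggio)
print(len(arpeggio), len(X), len(Y))  # 12 11 11
```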

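We will also shift the values so that they range from 0 to 46. Here 0 represents silence, and values 1 to 46 represent MIDI notes 36 (C2) to 81 (A5), which gives 47 classes in total.
We will train the model on windows of 33 chords (132 notes). The input is the first 131 notes, and the target is the last 131 notes, i.e., the same window shifted by one note. Consecutive windows start 16 chords apart, so they overlap.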
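
The window arithmetic is worth checking by hand. Below is a plain-Python sketch that mirrors the loop in `BachDataset._create_windows`; the helper `window_starts` is illustrative and not part of the notebook. Each window holds `window_size + 1` chords so that every input position has a target. As a result, the inputs fed to the model are 131 notes long, which matches the `(32, 131)` input used for the model summary further down.

```python
def window_starts(n_chords, window_size=32, window_shift=16):
    """Start indices as in BachDataset: each window covers window_size + 1 chords."""
    return list(range(0, n_chords - window_size, window_shift))

window_size = 32
notes_per_window = (window_size + 1) * 4  # 33 chords x 4 voices = 132 notes
seq_len = notes_per_window - 1            # X = notes[:-1], Y = notes[1:] -> 131 each
print(notes_per_window, seq_len)          # 132 131

# Shortest and longest chorale lengths mentioned above (100 and 640 time steps)
print(len(window_starts(100)), len(window_starts(640)))  # 5 38
```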

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

def preprocess(window):
    # Shift values: keep 0 as 0 (silence), shift other notes to start from 1
    window = torch.where(window == 0, window, window - min_note + 1)
    return window.reshape(-1)  # convert to arpeggio (flatten to 1D sequence)

class BachDataset(Dataset):
    def __init__(self, chorales, window_size=32, window_shift=16):
        self.chorales = chorales
        self.window_size = window_size
        self.window_shift = window_shift
        self.windows = self._create_windows()
    
    def _create_windows(self):
        windows = []
        for chorale in self.chorales:
            chorale_tensor = torch.tensor(chorale, dtype=torch.long)
            
            # Create sliding windows
            for i in range(0, len(chorale) - self.window_size, self.window_shift):
                window = chorale_tensor[i:i + self.window_size + 1]
                if len(window) == self.window_size + 1:  # Ensure full window
                    windows.append(window)
        
        return windows
    
    def __len__(self):
        return len(self.windows)
    
    def __getitem__(self, idx):
        window = self.windows[idx]
        # Preprocess: shift note values and flatten
        preprocessed = preprocess(window)
        
        # Create input/target pairs: the target is the input shifted one note ahead
        X = preprocessed[:-1]
        Y = preprocessed[1:]  # predict the next note of the arpeggio at each step
        
        return X, Y

def bach_dataloader(chorales, batch_size=32, shuffle=False, window_size=32, window_shift=16):
    
    dataset = BachDataset(chorales, window_size, window_shift)
    
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    
    return dataloader


In [None]:
# Build the data loaders (only the training set is shuffled)
train_set = bach_dataloader(train_chorales, shuffle=True)
valid_set = bach_dataloader(valid_chorales)
test_set = bach_dataloader(test_chorales)

Now let's create the model:
- We could feed the note values directly to the model as floats, but this would probably not give good results, because the relationships between notes are not that simple. For example, if you replace a C3 with a C4, the melody will still sound fine, even though these notes are 12 semitones (one octave) apart. Conversely, if you replace a C3 with a C#3, the chord will very likely sound horrible, even though these notes are right next to each other. So we will use an embedding layer (`nn.Embedding`) to convert each note to a small vector representation. We will use 5-dimensional embeddings, so the output of this first layer will have a shape of [batch_size, seq_len, 5].
- We will then feed this data to a small WaveNet-like network: a stack of 4 causal 1D convolutional layers (`nn.Conv1d`) with kernel size 2 and doubling dilation rates (1, 2, 4, 8). We intersperse these layers with batch normalization layers (`nn.BatchNorm1d`) for faster, better convergence.
- Then one LSTM layer (`nn.LSTM`, 256 hidden units) tries to capture long-term patterns.
- Finally, a linear layer (`nn.Linear`) produces one score (logit) for each sequence in the batch, for each time step, and for each possible note (including silence). The output shape is therefore [batch_size, seq_len, 47]. We don't apply a softmax here, because `nn.CrossEntropyLoss` expects raw logits and applies log-softmax internally.
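
To make causal padding and dilation concrete, here is a pure-Python sketch of a single-channel causal dilated convolution. It is not the notebook's implementation, which uses `nn.Conv1d` after `F.pad`, but it follows the same cross-correlation convention. The sketch also computes the receptive field of the stack described above. With kernel size 2 and dilations 1, 2, 4 and 8, each output of the convolutional stack depends on the current note and the 15 previous ones (4 chords). Looking further back is left to the LSTM. The function names are illustrative.

```python
def causal_dilated_conv(x, weights, dilation):
    """Single-channel causal 1D convolution (cross-correlation, as in nn.Conv1d).

    x is left-padded with (kernel_size - 1) * dilation zeros, so y[t] only
    depends on x[t], x[t - dilation], x[t - 2 * dilation], ...
    """
    kernel_size = len(weights)
    padded = [0.0] * ((kernel_size - 1) * dilation) + list(x)
    return [
        sum(w * padded[t + j * dilation] for j, w in enumerate(weights))
        for t in range(len(x))
    ]

def receptive_field(kernel_size, dilations):
    """Number of input steps that can influence one output of the stacked layers."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)

x = [1.0, 2.0, 3.0, 4.0]
print(causal_dilated_conv(x, [1.0, 10.0], dilation=1))  # [10.0, 21.0, 32.0, 43.0]
print(causal_dilated_conv(x, [1.0, 10.0], dilation=2))  # [10.0, 20.0, 31.0, 42.0]
print(receptive_field(2, [1, 2, 4, 8]))                 # 16
```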

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torchinfo

class Bach_Chorale_NN(nn.Module):
    def __init__(self, n_notes=47, n_embedding_dims=5):
        super().__init__()
        
        # Embedding layer
        self.embedding = nn.Embedding(n_notes, n_embedding_dims)
        
        # Dilated Conv1d layers with batch normalization (causal padding is applied in forward)
        self.conv1 = nn.Conv1d(n_embedding_dims, 32, kernel_size=2, dilation=1)
        self.bn1 = nn.BatchNorm1d(32)
        self.conv2 = nn.Conv1d(32, 48, kernel_size=2, dilation=2)
        self.bn2 = nn.BatchNorm1d(48)
        self.conv3 = nn.Conv1d(48, 64, kernel_size=2, dilation=4)
        self.bn3 = nn.BatchNorm1d(64)
        self.conv4 = nn.Conv1d(64, 96, kernel_size=2, dilation=8)
        self.bn4 = nn.BatchNorm1d(96)
        self.lstm = nn.LSTM(96, 256, batch_first=True) # LSTM layer
        self.Linear1 = nn.Linear(256, n_notes) # Output layer
        
    def causal_pad(self, x, kernel_size, dilation):
        padding = (kernel_size - 1) * dilation
        return F.pad(x, (padding, 0))
    
    def forward(self, x):
        # x shape: (batch_size, seq_len) with integer note indices
        
        # Embedding: (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        x = self.embedding(x)
        # Transpose for Conv1D: (batch_size, seq_len, embedding_dim) -> (batch_size, embedding_dim, seq_len)
        x = x.transpose(1, 2)
        # Conv1D layers with causal padding and batch normalization
        x = self.causal_pad(x, 2, 1)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.causal_pad(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.causal_pad(x, 2, 4)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.causal_pad(x, 2, 8)
        x = F.relu(self.bn4(self.conv4(x)))
        
        # Transpose back for LSTM: (batch_size, channels, seq_len) -> (batch_size, seq_len, channels)
        x = x.transpose(1, 2)
        lstm_out, _ = self.lstm(x)  # (batch_size, seq_len, 256)
        
        output = self.Linear1(lstm_out)  # (batch_size, seq_len, n_notes)
        
        return output

In [None]:
# Create model
model = Bach_Chorale_NN(n_notes=47, n_embedding_dims=5)

torchinfo.summary(model, input_data=torch.ones((32, 131), dtype=torch.long))

In [None]:
# Device, optimizer, and loss function setup
device = torch.device('cpu')  # could be 'cuda' or 'mps' where available
model = model.to(device)
optimizer = torch.optim.NAdam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()


# Training function
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (X, Y) in enumerate(dataloader):
        X, Y = X.to(device), Y.to(device)

        # Forward pass
        outputs = model(X)  # Shape: (batch, seq_len, n_notes)
        
        # Reshape for loss computation
        outputs_flat = outputs.view(-1, outputs.size(-1))
        Y_flat = Y.view(-1)
        
        loss = criterion(outputs_flat, Y_flat)
        
        optimizer.zero_grad()
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Calculate accuracy
        _, predicted = outputs_flat.max(1)
        total += Y_flat.size(0)
        correct += predicted.eq(Y_flat).sum().item()
    
    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy

# Validation function
def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for X, Y in dataloader:
            X, Y = X.to(device), Y.to(device)
            
            outputs = model(X)
            
            outputs_flat = outputs.view(-1, outputs.size(-1))
            Y_flat = Y.view(-1)
            
            loss = criterion(outputs_flat, Y_flat)
            total_loss += loss.item()
            
            _, predicted = outputs_flat.max(1)
            total += Y_flat.size(0)
            correct += predicted.eq(Y_flat).sum().item()
    
    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy

# Training loop 
def train_model(model, train_loader, valid_loader, optimizer, criterion, epochs=15, device='cpu'):
    for epoch in range(epochs):
        # Training
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        
        # Validation
        val_loss, val_acc = validate(model, valid_loader, criterion, device)
        
        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print('-' * 50)


# Run training
train_model(model=model, train_loader=train_set, valid_loader=valid_set,  optimizer=optimizer, criterion=criterion, epochs=15, device=device)

Feel free to iterate on this model now and try to optimize it. For example, you could try removing the LSTM layer and replacing it with Conv1D layers. You could also play with the number of layers, the learning rate, the optimizer, and so on.

Once you're satisfied with the performance of the model on the validation set, you can evaluate it one last time on the test set:

In [None]:
def evaluate_model(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for X, Y in test_loader:
            X, Y = X.to(device), Y.to(device)
            
            outputs = model(X)
            
            # Reshape for loss computation
            outputs_flat = outputs.view(-1, outputs.size(-1))
            Y_flat = Y.view(-1)
            
            loss = criterion(outputs_flat, Y_flat)
            total_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = outputs_flat.max(1)
            total += Y_flat.size(0)
            correct += predicted.eq(Y_flat).sum().item()
    
    avg_loss = total_loss / len(test_loader)
    accuracy = 100. * correct / total
    
    print(f'Test Results:')
    print(f'Test Loss: {avg_loss:.4f}')
    print(f'Test Accuracy: {accuracy:.2f}%')
    
    return avg_loss, accuracy

# Evaluate the model on test set
test_loss, test_acc = evaluate_model(model, test_set, criterion, device)

Now let's write a function that generates a new chorale. We give it a few seed chords, and it converts them to arpeggios (the format expected by the model). It then uses the model to predict the next note, then the next, and so on. At the end, it groups the notes 4 by 4 to create chords again and returns the resulting chorale.
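
The bookkeeping around the model call is easy to get off by one: the value shift has to be undone, and the notes have to be regrouped into chords. Here is a plain-Python sketch of that step with a round-trip check. The helper names `encode`, `decode` and `to_chords` are hypothetical. `generate_chorale` does the same work with `torch.where` and `reshape(-1, 4)`.

```python
def encode(note, min_note=36):
    """Preprocessing: 0 stays silence, MIDI min_note.. maps to 1.."""
    return 0 if note == 0 else note - min_note + 1

def decode(token, min_note=36):
    """Inverse mapping applied after generation."""
    return 0 if token == 0 else token + min_note - 1

def to_chords(tokens, voices=4, min_note=36):
    """Undo the shift, then regroup notes `voices` by `voices`, dropping an incomplete last chord."""
    notes = [decode(t, min_note) for t in tokens]
    n_complete = len(notes) // voices
    return [notes[i * voices:(i + 1) * voices] for i in range(n_complete)]

chords = [[74, 70, 65, 58], [0, 0, 0, 0]]
tokens = [encode(n) for chord in chords for n in chord]
print(tokens)                   # [39, 35, 30, 23, 0, 0, 0, 0]
print(to_chords(tokens + [5]))  # [[74, 70, 65, 58], [0, 0, 0, 0]]
```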

In [None]:
def generate_chorale(model, seed_chords, length, device='cpu'):
    model.eval()  # Set model to evaluation mode
    
    with torch.no_grad():  # Disable gradient computation for inference
        # Convert seed chords to tensor and preprocess
        seed_tensor = torch.tensor(seed_chords, dtype=torch.long)
        arpegio = preprocess(seed_tensor)
        arpegio = arpegio.unsqueeze(0).to(device)  # Add batch dimension and move to device
        
        # Generate new notes
        for chord in range(length):
            for note in range(4):
                # Get model prediction for the current sequence
                outputs = model(arpegio)  # Shape: (1, seq_len, n_notes)
                
                # Get the prediction for the last timestep
                last_output = outputs[0, -1, :]  # Shape: (n_notes,)
                
                # Get the most likely next note
                next_note = torch.argmax(last_output, dim=-1, keepdim=True)  # Shape: (1,)
                
                # Append the predicted note to the sequence
                arpegio = torch.cat([arpegio, next_note.unsqueeze(0)], dim=1)
        
        # Convert back to original note range (reverse the preprocessing)
        arpegio = torch.where(arpegio == 0, arpegio, arpegio + min_note - 1)
        
        # Reshape to chord format (group every 4 notes)
        arpegio_flat = arpegio.squeeze(0)  # Remove batch dimension
        n_total_notes = len(arpegio_flat)
        n_complete_chords = n_total_notes // 4
        
        # Take only complete chords and reshape
        chorale = arpegio_flat[:n_complete_chords * 4].reshape(-1, 4)
        
        return chorale.cpu().numpy()  # Move to CPU (needed if device is a GPU), then convert to NumPy

To test this function, we need some seed chords. Let's use the first 12 chords of one of the test chorales (it's actually just 3 different chords, each played 4 times):

In [None]:
seed_chords = test_chorales[2][:12]
baroque_synth.play_chorale(seed_chords)

Now we are ready to generate our first chorale! Let's ask the function to generate 32 more chords, for a total of 44 chords. That is 11 bars, assuming 4 chords per bar (a 4/4 time signature):

In [None]:
new_chorale = generate_chorale(model, seed_chords, 32, device=device)
baroque_synth.play_chorale(new_chorale)

This approach has one major flaw: it is often too conservative. The model never takes any risk, because it always chooses the note with the highest score. Repeating the previous note generally sounds good enough and is the least risky option, so the algorithm tends to make notes last longer and longer. Pretty boring. Plus, if you run the model multiple times, it will always generate the same melody.

So let's spice things up a bit! Instead of always picking the note with the highest score, we will pick the next note randomly, according to the predicted probabilities. For example, if the model predicts a C3 with 75% probability and a G3 with 25% probability, we pick one of these two notes at random with those probabilities. We will also add a temperature parameter that controls how "hot" (i.e., daring) we want the system to feel. A high temperature brings the predicted probabilities closer together, which lowers the probability of the likely notes and raises the probability of the unlikely ones. A low temperature (below 1) does the opposite and makes the output more conservative. A temperature of 1 leaves the model's probabilities unchanged.
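
Temperature scaling amounts to dividing the logits by $T$ before the softmax: $p_i = \exp(z_i/T) / \sum_j \exp(z_j/T)$. The minimal plain-Python sketch below uses the two-note example from the paragraph above. It chooses the logits so that $T = 1$ gives exactly 75% and 25%. `Categorical(logits=...)` in `generate_chorale_v2` performs the same normalization and sampling on tensors. The helper name `softmax_with_temperature` is illustrative.

```python
import math
import random

def softmax_with_temperature(logits, temperature=1.0):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Logits for the C3 / G3 example: at T = 1 they give 75% / 25%
logits = [math.log(0.75), math.log(0.25)]
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
# 0.5 [0.9, 0.1]
# 1.0 [0.75, 0.25]
# 2.0 [0.634, 0.366]

# Sampling according to the probabilities (seeded for reproducibility)
rng = random.Random(0)
probs = softmax_with_temperature(logits, 1.0)
samples = rng.choices(["C3", "G3"], weights=probs, k=10_000)
print(samples.count("C3") / len(samples))  # close to 0.75
```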

In [None]:
from torch.distributions import Categorical

def generate_chorale_v2(model, seed_chords, length, temperature=1, device='cpu'):
    model.eval()  # Set model to evaluation mode
    
    with torch.no_grad():  # Disable gradient computation for inference
        # Convert seed chords to tensor and preprocess
        seed_tensor = torch.tensor(seed_chords, dtype=torch.long)
        arpegio = preprocess(seed_tensor)
        arpegio = arpegio.unsqueeze(0).to(device)  # Add batch dimension and move to device

        for chord in range(length):
            for note in range(4):

                outputs = model(arpegio)  # Shape: (1, seq_len, n_notes)
                
                # Get the prediction for the last timestep
                last_output = outputs[0, -1, :]  # Shape: (n_notes,)   

                rescaled_logits = last_output / temperature

                categorical = Categorical(logits=rescaled_logits)
                next_note = categorical.sample().unsqueeze(0) # Shape: (1,)
               
                # Append the predicted note to the sequence
                arpegio = torch.cat([arpegio, next_note.unsqueeze(0)], dim=1)
    
    # Convert back to original note range (reverse the preprocessing)
    arpegio = torch.where(arpegio == 0, arpegio, arpegio + min_note - 1)
    
    # Reshape to chord format (group every 4 notes)
    arpegio_flat = arpegio.squeeze(0)  # Remove batch dimension
    n_total_notes = len(arpegio_flat)
    n_complete_chords = n_total_notes // 4
    
    # Take only complete chords and reshape
    chorale = arpegio_flat[:n_complete_chords * 4].reshape(-1, 4)
    
    return chorale.cpu().numpy()  # Move to CPU (needed if device is a GPU), then convert to NumPy

In [None]:
cold_chorale = generate_chorale_v2(model, seed_chords, 32, temperature=0.8, device=device)
baroque_synth.play_chorale(cold_chorale)

In [None]:
medium_chorale = generate_chorale_v2(model, seed_chords, 32, temperature=1, device=device)
baroque_synth.play_chorale(medium_chorale)

In [None]:
hot_chorale = generate_chorale_v2(model, seed_chords, 32, temperature=1.5, device=device)
baroque_synth.play_chorale(hot_chorale)