# MUSIC GAN

## INITIALIZATION

### Imports

In [43]:
import music21
import os
import pickle as pkl
import keras
import tensorflow as tf
from typing import List, Tuple, Optional
import requests
import glob
from tensorflow.keras import(
    layers
)

### Const

In [22]:
# True: parse the MIDI files from scratch; False: reload sequences pickled by a previous parse.
PARSE_MIDI_FILES = True
# NOTE(review): 'parseed' looks like a typo, but renaming the directory would orphan
# pickles saved by earlier runs, so it is left unchanged.
PARSED_DATA_PATH = './datasets/bach-cello-parseed/'
DATA_PATH = './datasets/bach-cello/'  # directory globbed for '*.mid' files
DATASET_REPETITIONS = 1  # how many times the training dataset is repeated
DATASET_REPETTIOTIONS = DATASET_REPETITIONS  # misspelled alias kept for existing references
SEQ_LEN = 50  # tokens per model input; sequences are parsed as SEQ_LEN + 1 tokens
# Presumably transformer hyperparameters consumed by model code outside this view — confirm.
EMBEDDING_DIM = 256
KEY_DIM = 256
N_HEADS = 5
DROPOUT_RATE = 0.3
FEED_FORWARD_DIM = 256
LOAD_MODEL = False

# Optimization
EPOCHS = 5000
BATCH_SIZE = 256
GENERATE_LEN = 50

### Functions

In [52]:
def _midi_element_to_token(element) -> Optional[Tuple[str, str]]:
    """Map one flattened music21 element to a (note token, duration token) pair, or None to skip it."""
    if isinstance(element, music21.key.Key):
        return f"{element.tonic.name}:{element.mode}", '0.0'
    if isinstance(element, music21.meter.TimeSignature):
        return f"{element.ratioString}TS", '0.0'
    if isinstance(element, music21.chord.Chord):
        # Only the last pitch of the chord's pitch list is kept (single-voice tokens).
        return element.pitches[-1].nameWithOctave, str(element.duration.quarterLength)
    if isinstance(element, music21.note.Rest):
        return str(element.name), str(element.duration.quarterLength)
    if isinstance(element, music21.note.Note):
        return element.nameWithOctave, str(element.duration.quarterLength)
    return None


def parse_midi_files(
    file_list: List[str],
    parser: music21.converter.Converter,
    seq_len: int,
    parsed_data_path: Optional[str] = None
) -> Tuple[List[str], List[str]]:
    """
    Parse MIDI files into overlapping, space-separated note and duration sequences.

    Each file is chordified and its flattened elements are turned into tokens:
    key signatures become "<tonic>:<mode>", time signatures "<ratio>TS", chords the
    last entry of their pitch list (with octave), rests "rest" and notes their name
    with octave. Each file is prefixed with a "START" token whose duration is "0.0".
    Tokens from all files are concatenated into one stream, and every window of
    `seq_len` consecutive tokens is emitted. Windows can therefore span the boundary
    between two files.

    Args:
        file_list (List[str]): Paths of the MIDI files to parse.
        parser (music21.converter.Converter): Object exposing `parse(path)` that returns
            a music21 stream. The notebook passes the `music21.converter` module itself.
        seq_len (int): Number of tokens per output sequence.
        parsed_data_path (Optional[str]): If given, directory (created if missing) in
            which `notes.pkl` and `durations.pkl` are written.

    Returns:
        Tuple[List[str], List[str]]: `(notes_list, duration_list)`, parallel lists of
        space-joined strings, each holding `seq_len` tokens.
    """
    notes: List[str] = []  # token stream accumulated over all files
    durations: List[str] = []  # parallel duration tokens

    for i, file in enumerate(file_list):
        print(i + 1, f'Parsing {file}')
        score = parser.parse(file).chordify()

        # Mark the start of every piece.
        notes.append('START')
        durations.append('0.0')

        # `.flat` is deprecated since music21 v7; `.flatten()` yields the same flattened stream.
        for element in score.flatten():
            token = _midi_element_to_token(element)
            if token is not None:
                note_name, duration_name = token
                notes.append(note_name)
                durations.append(duration_name)

        print(f'{len(notes)} notes parsed')  # cumulative count over all files so far

    print(f'Building sequences of length {seq_len}')
    # A stream of L tokens contains L - seq_len + 1 windows of length seq_len
    # (the previous `range(L - seq_len)` silently dropped the last one).
    n_windows = len(notes) - seq_len + 1
    notes_list = [' '.join(notes[i: i + seq_len]) for i in range(n_windows)]
    duration_list = [' '.join(durations[i: i + seq_len]) for i in range(n_windows)]

    # Only touch the filesystem when a destination was requested; calling
    # os.makedirs(None) would raise TypeError for the default argument.
    if parsed_data_path:
        os.makedirs(parsed_data_path, exist_ok=True)
        with open(os.path.join(parsed_data_path, 'notes.pkl'), 'wb') as f:
            pkl.dump(notes_list, f)
        with open(os.path.join(parsed_data_path, 'durations.pkl'), 'wb') as f:
            pkl.dump(duration_list, f)

    return notes_list, duration_list

def load_parsed_files(parsed_data_path: Optional[str] = None) -> Tuple[List[str], List[str]]:
    """
    Load the note and duration sequences pickled by `parse_midi_files`.

    Args:
        parsed_data_path (Optional[str]): Directory containing `notes.pkl` and
            `durations.pkl`. Defaults to the notebook-level `PARSED_DATA_PATH`.
            Files named `notes`/`durations` without an extension are accepted as a
            fallback.

    Returns:
        Tuple[List[str], List[str]]: `(notes, durations)` as stored on disk.

    Raises:
        FileNotFoundError: If neither the `.pkl` nor the extension-less file exists.
    """
    if parsed_data_path is None:
        parsed_data_path = PARSED_DATA_PATH

    def _load(stem: str):
        # parse_midi_files writes '<stem>.pkl'; this loader used to read '<stem>'.
        path = os.path.join(parsed_data_path, f'{stem}.pkl')
        if not os.path.exists(path):
            legacy_path = os.path.join(parsed_data_path, stem)
            if os.path.exists(legacy_path):
                path = legacy_path
        # NOTE: unpickling can execute arbitrary code; only load files this notebook wrote.
        with open(path, 'rb') as f:
            return pkl.load(f)

    notes = _load('notes')
    durations = _load('durations')
    return notes, durations

def get_midi_notes(sample_note: str, sample_duration: str) -> music21.note.NotRest:
    """
    Convert one (note token, duration token) pair back into a music21 object.

    Args:
        sample_note (str): Note token, e.g. 'C4', 'rest', 'G:major', '4/4TS', 'START',
            or a dot-joined chord such as 'C4.E4.G4'.
        sample_duration (str): Duration in quarter lengths as a decimal or fraction
            string, e.g. '0.25', '1.0', '1/3'. Only parsed for notes, rests and chords.

    Returns:
        A `music21.meter.TimeSignature`, `music21.key.Key`, `music21.note.Rest`,
        `music21.chord.Chord` or `music21.note.Note`. Returns `None` for the 'START'
        token. The `NotRest` return annotation is kept for compatibility, but the
        actual return types are broader.
    """
    # Duration tokens come from str(quarterLength) and may be fractions such as '1/3'.
    # Fraction parses both fractions and plain decimals. It is not imported in the
    # notebook's import cell, so import it here.
    from fractions import Fraction

    def _duration():
        # Build a new Duration object for each note/rest, as the original code did.
        return music21.duration.Duration(float(Fraction(sample_duration)))

    # Time signature, e.g. '4/4TS'
    if 'TS' in sample_note:
        return music21.meter.TimeSignature(sample_note.split('TS')[0])

    # Key signature, e.g. 'G:major', 'A:minor'
    if 'major' in sample_note or 'minor' in sample_note:
        tonic, mode = sample_note.split(':')
        return music21.key.Key(tonic, mode)

    # Rest
    if sample_note == 'rest':
        new_note = music21.note.Rest()
        new_note.duration = _duration()
        new_note.storedInstrument = music21.instrument.Violoncello()
        return new_note

    # Chord, e.g. 'C4.E4.G4'. parse_midi_files in this notebook emits single-pitch
    # tokens, so this branch is only hit by dot-joined tokens from other sources.
    if '.' in sample_note:
        chord_notes = []
        for pitch_name in sample_note.split('.'):
            n = music21.note.Note(pitch_name)
            n.duration = _duration()
            n.storedInstrument = music21.instrument.Violoncello()
            chord_notes.append(n)
        return music21.chord.Chord(chord_notes)

    # Single note, e.g. 'C4'. 'START' is a sequence marker with no musical content.
    if sample_note != 'START':
        new_note = music21.note.Note(sample_note)
        new_note.duration = _duration()
        new_note.storedInstrument = music21.instrument.Violoncello()
        return new_note

    return None

def create_dataset(
    elements: List[str],
    *,
    batch_size: Optional[int] = None,
    shuffle: bool = True,
) -> Tuple[tf.data.Dataset, layers.TextVectorization, List[str]]:
    """
    Batch a list of space-separated token strings and fit a TextVectorization vocabulary to it.

    Args:
        elements (List[str]): Text samples, one space-separated token sequence each.
        batch_size (Optional[int]): Samples per batch. Defaults to the notebook-level
            `BATCH_SIZE`. The final partial batch is dropped.
        shuffle (bool): Whether to shuffle the batches (buffer of 1000 batches). The
            shuffle is unseeded and independent per call, so datasets built by two
            separate calls are NOT aligned with each other. When the elements of two
            datasets must stay paired, pass `shuffle=False` and shuffle after zipping.

    Returns:
        Tuple[tf.data.Dataset, layers.TextVectorization, List[str]]:
            - ds: the batched (and optionally shuffled) dataset.
            - vectorize_layer: a TextVectorization layer (no standardization, integer
              output) adapted to `ds`.
            - vocab: the layer's vocabulary. Per the notebook output, index 0 is '' and
              index 1 is '[UNK]'.

    Example:
        elements = ["hello world", "this is a test", "deep learning with tensorflow"]
        ds, vectorize_layer, vocab = create_dataset(elements)
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    # Batch first, so the shuffle below reorders whole batches, not individual samples.
    ds = tf.data.Dataset.from_tensor_slices(elements).batch(batch_size, drop_remainder=True)
    if shuffle:
        ds = ds.shuffle(1000)

    # No standardization: tokens such as 'F#3', 'G:major' and '1/3' must survive as-is.
    vectorize_layer = layers.TextVectorization(
        standardize=None,
        output_mode='int',
    )
    vectorize_layer.adapt(ds)

    vocab = vectorize_layer.get_vocabulary()
    return ds, vectorize_layer, vocab

def prepare_inputs(notes: tf.Tensor, durations: tf.Tensor) -> Tuple[Tuple[tf.Tensor, tf.Tensor], Tuple[tf.Tensor, tf.Tensor]]:
    """
    Tokenize a batch of note/duration strings and split it into next-token training pairs.

    Args:
        notes (tf.Tensor): A batch of space-separated note strings. Presumably shape
            (batch_size,) of dtype string, as yielded by the batched dataset (confirm).
        durations (tf.Tensor): A batch of space-separated duration strings, aligned
            element-for-element with `notes`.

    Returns:
        Tuple[Tuple[tf.Tensor, tf.Tensor], Tuple[tf.Tensor, tf.Tensor]]: `(x, y)` where
            - x = (note_ids[:, :-1], duration_ids[:, :-1]): every token except the last.
            - y = (note_ids[:, 1:], duration_ids[:, 1:]): every token except the first,
              i.e. the inputs shifted one step to the left.

    Relies on the module-level `notes_vectorize_layer` and `durations_vectorize_layer`.
    """

    def _tokenize(vectorizer, text_batch):
        # Add a trailing axis so each string is one row for TextVectorization,
        # which splits it into a row of integer token ids.
        return vectorizer(tf.expand_dims(text_batch, -1))

    note_ids = _tokenize(notes_vectorize_layer, notes)
    duration_ids = _tokenize(durations_vectorize_layer, durations)

    inputs = (note_ids[:, :-1], duration_ids[:, :-1])
    targets = (note_ids[:, 1:], duration_ids[:, 1:])
    return inputs, targets

### Classes

In [None]:
class SinePositionEncoding(keras.layers.Layer):
    """
    Sinusoidal positional encoding layer ("Attention Is All You Need", Vaswani et al., 2017).

    For position `p` and feature index `i`, the encoding is sin(p * w_i) at even `i` and
    cos(p * w_i) at odd `i`, where w_i = (1 / max_wavelength) ** (2 * (i // 2) / hidden_size).

    Args:
        max_wavelength (int): Largest angular wavelength of the sine/cosine curves.
            Default 10000.

    Input:
        inputs (tf.Tensor): Tensor whose last two axes are (sequence_length, hidden_size),
            e.g. [batch_size, sequence_length, hidden_size].

    Output:
        tf.Tensor: The positional encodings broadcast to the input's shape.

    Example:
        ```python
        seq_len = 100
        vocab_size = 1000
        embedding_dim = 32

        inputs = keras.Input((seq_len,), dtype="int32")
        embedding = keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim)(inputs)
        positional_encoding = SinePositionEncoding()(embedding)
        outputs = embedding + positional_encoding
        ```
    """

    def __init__(self, max_wavelength: int = 10000, **kwargs):
        """
        Store the wavelength bound. Other keyword arguments go to `keras.layers.Layer`.

        Args:
            max_wavelength (int): Largest angular wavelength of the encoding curves.
            **kwargs: Forwarded unchanged to the base layer constructor.
        """
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Build the sinusoidal encodings for `inputs` and broadcast them to its shape.

        Args:
            inputs (tf.Tensor): Tensor whose last two axes are (sequence_length, hidden_size).

        Returns:
            tf.Tensor: Encodings with the same shape as `inputs`, in the layer's compute dtype.
        """
        dtype = self.compute_dtype
        full_shape = tf.shape(inputs)
        length = full_shape[-2]
        width = full_shape[-1]

        feature_index = tf.range(width)
        positions = tf.cast(tf.range(length), dtype)

        # Per-feature frequencies: pairs of features (2k, 2k + 1) share one frequency.
        lowest_freq = tf.cast(1 / self.max_wavelength, dtype=dtype)
        exponents = (
            tf.cast(2 * (feature_index // 2), dtype)
            / tf.cast(width, dtype)
        )
        frequencies = tf.pow(lowest_freq, exponents)

        # [length, width] grid of position * frequency.
        phase = tf.expand_dims(positions, 1) * tf.expand_dims(frequencies, 0)

        # Odd features take the cosine, even features the sine.
        odd_selector = tf.cast(feature_index % 2, dtype)
        even_selector = 1 - odd_selector
        encoding = tf.sin(phase) * even_selector + tf.cos(phase) * odd_selector

        return tf.broadcast_to(encoding, full_shape)

    def get_config(self) -> dict:
        """
        Serialize the layer, adding `max_wavelength` to the base config.

        Returns:
            dict: The layer configuration.
        """
        cfg = super().get_config()
        cfg["max_wavelength"] = self.max_wavelength
        return cfg
        

### Dataset

In [None]:
# Download into DATA_PATH, the directory the parsing cells glob for '*.mid'.
# (This cell previously used 'data/bach-cello', which nothing downstream reads.)
output_dir = DATA_PATH
os.makedirs(output_dir, exist_ok=True)

# Bach cello suites 1-6, six movements each
urls = [
    "http://www.jsbach.net/midi/cs1-1pre.mid",
    "http://www.jsbach.net/midi/cs1-2all.mid",
    "http://www.jsbach.net/midi/cs1-3cou.mid",
    "http://www.jsbach.net/midi/cs1-4sar.mid",
    "http://www.jsbach.net/midi/cs1-5men.mid",
    "http://www.jsbach.net/midi/cs1-6gig.mid",
    "http://www.jsbach.net/midi/cs2-1pre.mid",
    "http://www.jsbach.net/midi/cs2-2all.mid",
    "http://www.jsbach.net/midi/cs2-3cou.mid",
    "http://www.jsbach.net/midi/cs2-4sar.mid",
    "http://www.jsbach.net/midi/cs2-5men.mid",
    "http://www.jsbach.net/midi/cs2-6gig.mid",
    "http://www.jsbach.net/midi/cs3-1pre.mid",
    "http://www.jsbach.net/midi/cs3-2all.mid",
    "http://www.jsbach.net/midi/cs3-3cou.mid",
    "http://www.jsbach.net/midi/cs3-4sar.mid",
    "http://www.jsbach.net/midi/cs3-5bou.mid",
    "http://www.jsbach.net/midi/cs3-6gig.mid",
    "http://www.jsbach.net/midi/cs4-1pre.mid",
    "http://www.jsbach.net/midi/cs4-2all.mid",
    "http://www.jsbach.net/midi/cs4-3cou.mid",
    "http://www.jsbach.net/midi/cs4-4sar.mid",
    "http://www.jsbach.net/midi/cs4-5bou.mid",
    "http://www.jsbach.net/midi/cs4-6gig.mid",
    "http://www.jsbach.net/midi/cs5-1pre.mid",
    "http://www.jsbach.net/midi/cs5-2all.mid",
    "http://www.jsbach.net/midi/cs5-3cou.mid",
    "http://www.jsbach.net/midi/cs5-4sar.mid",
    "http://www.jsbach.net/midi/cs5-5gav.mid",
    "http://www.jsbach.net/midi/cs5-6gig.mid",
    "http://www.jsbach.net/midi/cs6-1pre.mid",
    "http://www.jsbach.net/midi/cs6-2all.mid",
    "http://www.jsbach.net/midi/cs6-3cou.mid",
    "http://www.jsbach.net/midi/cs6-4sar.mid",
    "http://www.jsbach.net/midi/cs6-5gav.mid",
    "http://www.jsbach.net/midi/cs6-6gig.mid"
]

def download_file(url, output_dir, timeout=30):
    """
    Download `url` into `output_dir`, saving it under the URL's base name.

    Best-effort: a non-200 HTTP status is reported with a printed message and nothing
    is written. Network errors raised by `requests` (including a timeout) propagate.

    Args:
        url: HTTP(S) URL of the file to fetch.
        output_dir: Existing directory to write the file into.
        timeout: Seconds `requests` waits for the server. Without a timeout,
            `requests.get` can block indefinitely on an unresponsive host.
    """
    file_name = os.path.join(output_dir, os.path.basename(url))
    print(f"Downloading {file_name}...")
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        print(f"❌ Failed to download: {file_name} (Status: {response.status_code})")
        return
    with open(file_name, 'wb') as f:
        f.write(response.content)
    print(f"✅ Downloaded: {file_name}")

# Download each file in the list
#for url in urls:
#    download_file(url, output_dir)

In [30]:
# glob returns files in arbitrary, filesystem-dependent order. Sort the list so parsing
# (and the hard-coded offsets used later, e.g. notes[658]) is reproducible.
file_list = sorted(glob.glob(os.path.join(DATA_PATH, '*.mid')))
# parse_midi_files only needs an object with a `parse(path)` method; the converter module provides it.
parser = music21.converter

In [35]:
# Parse the second MIDI file, keep the first piece produced by splitting at quarter
# length 12, and chordify it for inspection.
second_file_stream = music21.converter.parse(file_list[1])
example_score = second_file_stream.splitAtQuarterLength(12)[0].chordify()

In [37]:
# Print the excerpt's element tree as text (offsets in braces, nested by measure).
example_score.show(fmt="text")

{0.0} <music21.metadata.Metadata object at 0x7fe076878150>
{0.0} <music21.stream.Measure 1 offset=0.0>
    {0.0} <music21.instrument.Violoncello 'Solo Cello: Solo Cello'>
    {0.0} <music21.instrument.Violoncello 'Violoncello'>
    {0.0} <music21.clef.BassClef>
    {0.0} <music21.tempo.MetronomeMark Quarter=250>
    {0.0} <music21.key.Key of G major>
    {0.0} <music21.meter.TimeSignature 4/4>
    {0.0} <music21.note.Rest 3.75ql>
    {3.5} <music21.tempo.MetronomeMark Quarter=77>
    {3.75} <music21.chord.Chord B3>
{4.0} <music21.stream.Measure 2 offset=4.0>
    {0.0} <music21.chord.Chord G2 D3 B3>
    {1.0} <music21.chord.Chord B3>
    {1.25} <music21.chord.Chord A3>
    {1.5} <music21.chord.Chord G3>
    {1.75} <music21.chord.Chord F#3>
    {2.0} <music21.chord.Chord G3>
    {2.25} <music21.chord.Chord D3>
    {2.5} <music21.chord.Chord E3>
    {2.75} <music21.chord.Chord F#3>
    {3.0} <music21.chord.Chord G3>
    {3.25} <music21.chord.Chord A3>
    {3.5} <music21.chord.Chord B3>
  

In [40]:
# Sequences are SEQ_LEN + 1 tokens long so prepare_inputs can shift them into
# SEQ_LEN-long inputs and targets.
if PARSE_MIDI_FILES:
    notes, durations = parse_midi_files(file_list, parser, SEQ_LEN + 1, PARSED_DATA_PATH)
else:
    # Reload from the same directory parse_midi_files writes its pickles to.
    notes, durations = load_parsed_files(PARSED_DATA_PATH)

1 Parsing ./datasets/bach-cello/cs1-1pre.mid
658 notes parsed
2 Parsing ./datasets/bach-cello/cs1-2all.mid


  notes, durations= parse_midi_files(file_list, parser, SEQ_LEN + 1, PARSED_DATA_PATH)


1579 notes parsed
3 Parsing ./datasets/bach-cello/cs1-3cou.mid
2399 notes parsed
4 Parsing ./datasets/bach-cello/cs1-4sar.mid
2662 notes parsed
5 Parsing ./datasets/bach-cello/cs1-5men.mid
3309 notes parsed
6 Parsing ./datasets/bach-cello/cs1-6gig.mid
3735 notes parsed
7 Parsing ./datasets/bach-cello/cs2-1pre.mid
4373 notes parsed
8 Parsing ./datasets/bach-cello/cs2-2all.mid
5066 notes parsed
9 Parsing ./datasets/bach-cello/cs2-3cou.mid
5807 notes parsed
10 Parsing ./datasets/bach-cello/cs2-4sar.mid
6144 notes parsed
11 Parsing ./datasets/bach-cello/cs2-5men.mid
6671 notes parsed
12 Parsing ./datasets/bach-cello/cs2-6gig.mid
7406 notes parsed
13 Parsing ./datasets/bach-cello/cs3-1pre.mid
8387 notes parsed
14 Parsing ./datasets/bach-cello/cs3-2all.mid
9124 notes parsed
15 Parsing ./datasets/bach-cello/cs3-3cou.mid
10113 notes parsed
16 Parsing ./datasets/bach-cello/cs3-4sar.mid
10454 notes parsed
17 Parsing ./datasets/bach-cello/cs3-5bou.mid
11335 notes parsed
18 Parsing ./datasets/bach

In [42]:
# The first file produced 658 tokens (see the parse log), so the window starting at
# index 658 begins with the second file's START token.
example_notes = notes[658]
example_durations = durations[658]
for caption, sequence in (('\nNotes string\n', example_notes),
                          ('\nDurations string\n', example_durations)):
    print(caption, sequence, '...')


Notes string
 START G:major 4/4TS rest B3 B3 B3 A3 G3 F#3 G3 D3 E3 F#3 G3 A3 B3 C4 D4 B3 G3 F#3 G3 E3 D3 C3 B2 C3 D3 E3 F#3 G3 A3 B3 C4 A3 G3 F#3 G3 E3 F#3 G3 A2 D3 F#3 G3 A3 B3 C4 A3 B3 ...

Durations string
 0.0 0.0 0.0 3.75 0.25 1.0 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 ...


## TOKENIZE DATA

In [46]:
# Fit one vocabulary per token stream (the per-stream datasets are kept for reference).
notes_seq_ds, notes_vectorize_layer, notes_vocab = create_dataset(notes)
durations_seq_ds, durations_vectorize_layer, durations_vocab = create_dataset(durations)

# Each create_dataset call shuffles its own batches independently (no shared seed).
# Zipping the two returned datasets therefore pairs note batches with unrelated
# duration batches. Batch and shuffle the (notes, durations) pairs together instead,
# so each note sequence stays with its own durations.
seq_ds = (
    tf.data.Dataset.from_tensor_slices((notes, durations))
    .batch(BATCH_SIZE, drop_remainder=True)
    .shuffle(1000)
)

2024-12-05 22:36:10.494056: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-12-05 22:36:11.212523: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [50]:
example_tokenised_notes = notes_vectorize_layer(example_notes)
example_tokenised_durations = durations_vectorize_layer(example_durations)

# Show the first 11 (note id, duration id) pairs side by side.
print('{:10} {:10}'.format('note token', 'duration token'))
first_note_ids = example_tokenised_notes.numpy()[:11]
first_duration_ids = example_tokenised_durations.numpy()[:11]
for note_int, duration_int in zip(first_note_ids, first_duration_ids):
    print(f'{note_int:10} {duration_int:10}')

note token duration token
        37          9
        51          9
        42          9
        33         18
         9          2
         9          4
         9          2
         3          2
         2          2
        12          2
         2          2


In [51]:
notes_vocab_size = len(notes_vocab)
durations_vocab_size = len(durations_vocab)

# Print each vocabulary's size followed by its first 10 entries.
for header, vocab_list in (
    (f'\nNOTES_VOCAB: length= {len(notes_vocab)}', notes_vocab),
    (f'\nDURATIONS_VOCAB: length= {len(durations_vocab)}', durations_vocab),
):
    print(header)
    for idx, token in enumerate(vocab_list[:10]):
        print(f'{idx}: {token}')


NOTES_VOCAB: length= 59
0: 
1: [UNK]
2: G3
3: A3
4: D3
5: F3
6: C4
7: D4
8: E3
9: B3

DURATIONS_VOCAB: length= 24
0: 
1: [UNK]
2: 0.25
3: 0.5
4: 1.0
5: 1/3
6: 0.75
7: 1/12
8: 1.5
9: 0.0


## TRAINING DATASET

In [53]:
# prepare_inputs turns each (notes, durations) batch of strings into
# ((input note ids, input duration ids), (target note ids, target duration ids)).
input_target_pairs = seq_ds.map(prepare_inputs)

# Repeat the whole dataset DATASET_REPETTIOTIONS times (the constant's name is
# misspelled in the constants cell).
ds = input_target_pairs.repeat(DATASET_REPETTIOTIONS)


In [54]:
# Pull a single batch to sanity-check the input/target shapes and token ids.
first_batch_ds = ds.take(1)
example_input_output = first_batch_ds.get_single_element()
print(example_input_output)

((<tf.Tensor: shape=(256, 50), dtype=int64, numpy=
array([[ 2,  5, 31, ..., 14, 11,  2],
       [ 5, 31,  2, ..., 11,  2, 10],
       [31,  2,  5, ...,  2, 10,  4],
       ...,
       [11, 14,  2, ...,  6, 16,  6],
       [14,  2,  5, ..., 16,  6, 11],
       [ 2,  5,  2, ...,  6, 11, 14]])>, <tf.Tensor: shape=(256, 50), dtype=int64, numpy=
array([[3, 3, 3, ..., 3, 3, 3],
       [3, 3, 3, ..., 3, 3, 3],
       [3, 3, 3, ..., 3, 3, 3],
       ...,
       [3, 3, 3, ..., 3, 3, 3],
       [3, 3, 3, ..., 3, 3, 3],
       [3, 3, 3, ..., 3, 3, 3]])>), (<tf.Tensor: shape=(256, 50), dtype=int64, numpy=
array([[ 5, 31,  2, ..., 11,  2, 10],
       [31,  2,  5, ...,  2, 10,  4],
       [ 2,  5, 10, ..., 10,  4, 13],
       ...,
       [14,  2,  5, ..., 16,  6, 11],
       [ 2,  5,  2, ...,  6, 11, 14],
       [ 5,  2, 11, ..., 11, 14,  2]])>, <tf.Tensor: shape=(256, 50), dtype=int64, numpy=
array([[3, 3, 3, ..., 3, 3, 3],
       [3, 3, 3, ..., 3, 3, 3],
       [3, 3, 3, ..., 3, 3, 3],
       ...,