# Dataset and Preprocessing

In [1]:
# get chorale dataset
# https://github.com/ageron/handson-ml2/blob/master/datasets/jsb_chorales/README.md

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
device='mps' if torch.backends.mps.is_available() else 'cpu'

In [3]:
# preprocess the chorales
import os
import csv

note_range = [88, 0]

def parse_score_dir(folder_path):
    scores = []

    # Loop through all files in the directory
    for filename in os.listdir(folder_path):
        if filename.endswith('.csv'):
            file_path = os.path.join(folder_path, filename)
            #print(f"Contents of {filename}:")

            # Open the CSV file
            with open(file_path, 'r') as csv_file:
                csv_reader = csv.reader(csv_file)
                next(csv_reader) # skip header

                voices = [[] for _ in range(4)]
                for row in csv_reader:
                    for i in range(4):
                        voices[i].append(int(row[i]))

                        if int(row[i]) != 0:
                            note_range[0] = min(note_range[0], int(row[i]))
                            note_range[1] = max(note_range[1], int(row[i]))

                scores.append([filename, voices])

    return scores

test = parse_score_dir("test")
train = parse_score_dir("train")
valid = parse_score_dir("valid")

print(note_range)

[36, 81]


In [4]:
import math
from secrets import randbelow
import torch

def rand_num(min, max_excl):
    return randbelow(int(max_excl - min)) + min

def split_scores(scores, repeats=5):
    res = []
    for score in scores:
        filename, voices = score
        for r in range(repeats):
            num_beats = len(voices[0]) / 4
            src_seq_length = rand_num(16, 21)

            # pick random index for training seq
            i = randbelow(int(num_beats - src_seq_length + 1))
            src_voices = [v[i*4:(i+src_seq_length)*4] for v in voices]

            # pick index for tgt
            tgt_voices = None
            if i > (num_beats - (i + src_seq_length)):
                # tgt at beginning
                tgt_seq_length = min(20, rand_num(2, i))
                tgt_i = randbelow(i - tgt_seq_length + 1)
                tgt_voices = [v[tgt_i*4:(tgt_i+tgt_seq_length)*4] for v in voices]
            else:
                # tgt at end
                tgt_seq_length = min(20, rand_num(2, (num_beats - (i + src_seq_length))))
                tgt_i = rand_num(i + src_seq_length + 1, num_beats - tgt_seq_length + 1)
                tgt_voices = [v[tgt_i*4:(tgt_i+tgt_seq_length)*4] for v in voices]

            res.append([filename, src_voices, tgt_voices])

    return res

split_train = split_scores(train)
split_test = split_scores(test)
split_valid = split_scores(valid)

# max_src = 0
# max_tgt = 0
# for s in split_train:
#     max_src = max(len(s[1][0]), max_src)
#     max_tgt = max(len(s[2][0]), max_tgt)

# print(max_src)
# print(max_tgt)

In [5]:
from torch.utils.data import DataLoader, TensorDataset
import torch

PAD_TOKEN_ID = 0
EOS_TOKEN_ID = 1
NO_NOTE_TOKEN_ID = 2
NOTE_OFFSET = note_range[0] - 3
VOCAB_SIZE = 3 + note_range[1] - note_range[0] + 1

def to_dataset(split_scores, max_src_seq_len=324, max_tgt_seq_len=324):
    src_data = []
    tgt_data = []
    for score in split_scores:
        filename, src_voices, tgt_voices = score

        def note_to_id(note):
            if note == 0:
                return NO_NOTE_TOKEN_ID
            return note - NOTE_OFFSET

        def voices_to_src_seq(voices):
            seq = torch.zeros(max_src_seq_len, dtype=torch.int64)
            for i in range(len(voices[0])):
                seq[i*4] = note_to_id(voices[0][i])
                seq[i*4 + 1] = note_to_id(voices[1][i])
                seq[i*4 + 2] = note_to_id(voices[2][i])
                seq[i*4 + 3] = note_to_id(voices[3][i])
            return seq

        def voices_to_tgt_seq(voices):
            seq = torch.zeros(max_tgt_seq_len, dtype=torch.int64)
            for i in range(len(voices[0])):
                seq[i*4] = note_to_id(voices[0][i])
                seq[i*4 + 1] = note_to_id(voices[1][i])
                seq[i*4 + 2] = note_to_id(voices[2][i])
                seq[i*4 + 3] = note_to_id(voices[3][i])
            seq[(i+1)*4] = EOS_TOKEN_ID
            return seq

        src_data.append(voices_to_src_seq(src_voices))
        tgt_data.append(voices_to_tgt_seq(tgt_voices))

    dataset = TensorDataset(torch.stack(src_data), torch.stack(tgt_data))
    return dataset

train_dataset = to_dataset(split_train)
test_dataset = to_dataset(split_test)
valid_dataset = to_dataset(split_valid)

# Score Visualization

In [6]:
import math
from pathlib import Path
import musicscore
from musicscore import *
import uuid

gen_id = lambda: f"a{str(uuid.uuid4()).replace('-', '')}"

def visualize_score(score, name):
    title, content = score
    newScore = Score(title=title)

    soprano = newScore.add_child(Part(gen_id(), name='Soprano'))
    alto = newScore.add_child(Part(gen_id(), name='Alto'))
    tenor = newScore.add_child(Part(gen_id(), name='Tenor'))
    bass = newScore.add_child(Part(gen_id(), name='Bass'))

    parts = [soprano, alto, tenor, bass]
    clefs = [TrebleClef(), TrebleClef(), BassClef(), BassClef()]

    for part_index in range(4):
        part_content = content[part_index]
        part = parts[part_index]
        clef = clefs[part_index]

        if (len(part_content) % 16 != 0):
            # add pickup measure
            pickuptime = musicscore.time.Time()
            pickuptime.actual_signatures = [1, 4]

            measure = part.add_child(Measure(number=1, time=pickuptime))
            staff = measure.add_child(Staff(clef=clef))

            normaltime = musicscore.time.Time()
            measure = part.add_child(Measure(number=2, time=normaltime))
        else:
            measure = part.add_child(Measure(number=1))
            staff = measure.add_child(Staff(clef=clef))

        note_index = 0
        while note_index < len(part_content):
            note = part_content[note_index]
            duration = .25
            while (note_index != len(part_content) - 1 and (note_index + 1) % 4 != 0 and note == part_content[note_index + 1]):
                duration += .25
                note_index += 1

            accidental = musicscore.accidental.Accidental("sharp")

            #check if next note is descending
            check_index = note_index + 1
            while True:
                if check_index == len(part_content):
                    break;

                check_note = part_content[check_index]
                if (check_note != note):
                    if check_note < note and note - check_note < 3:
                        accidental.mode = "flat"
                    break;

                check_index += 1


            midi = musicscore.midi.Midi(note, accidental)
            chord = Chord(midi, duration)
            part.add_chord(chord)
            note_index += 1

    xml_path = Path(name).with_suffix('.xml')
    newScore.export_xml(xml_path)

# Neural Network

In [7]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, n=10000):
        super(PositionalEncoding, self).__init__()
        # pe = torch.zeros(max_len, d_model)
        # position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # pe[:, 0::2] = torch.sin(position * div_term)
        # pe[:, 1::2] = torch.cos(position * div_term)
        # pe = pe.unsqueeze(0).transpose(0, 1)
        # self.register_buffer('pe', pe)



        # Generate position encoding based on the given sequence length, hidden size, and frequency (n)

        horizontal_pe = torch.zeros(max_len, d_model)

        for pos in range(max_len):
            for i in range(d_model // 2):
                horizontal_pe[pos, 2*i] = torch.sin(torch.tensor(pos / (n**(2*i / d_model))))
                horizontal_pe[pos, 2*i+1] = torch.cos(torch.tensor(pos / (n**(2*i / d_model))))
        self.register_buffer('horizontal_pe', horizontal_pe)

    def forward(self, x):
        # print(x.shape)
        # print(self.horizontal_pe[:x.size(1), :].shape)
        return x + self.horizontal_pe[:x.size(1), :]

In [8]:
class PositionalEncoding2D(nn.Module):
    def __init__(self, d_model, max_len=320, n=10000):
        super(PositionalEncoding2D, self).__init__()
        self.max_len = max_len
        self.d_model = d_model
        self.n = n

        # Initialize positional encodings
        pe = torch.zeros(max_len * 4, d_model)
        for col in range(max_len):
            for row in range(4):
                pos = (row * 4) + col
                for i in range(d_model // 2):
                    pe[pos, 2*i] = torch.sin(torch.tensor(
                        (col / (n**(2*i / d_model))) +
                        (row / (n**(2*i / d_model)))))
                    pe[pos, 2*i+1] = torch.cos(torch.tensor(
                        (col / (n**(2*i / d_model))) +
                        (row / (n**(2*i / d_model)))))
        self.register_buffer('pe', pe)

    def forward(self, x):
        batch_size, seq_len, emb_dim = x.size()

        return x + self.pe[:seq_len, :]

In [9]:
class WeightedPositionalEncoding(nn.Module):
    def __init__(self, d_model, horizontal_bias=.5, num_rows=4, max_len=5000, n=10000):
        super(WeightedPositionalEncoding, self).__init__()

        # Assuming max_len is the maximum number of columns
        self.num_rows = num_rows
        self.d_model = d_model
        self.horizontal_bias = horizontal_bias

        # Precompute positional encodings for rows and columns
        # Rows: fixed at 4
        vertical_pe = torch.zeros(num_rows, d_model)
        horizontal_pe = torch.zeros(max_len, d_model)

        for i in range(num_rows):
            for j in range(d_model // 2):
                vertical_pe[i, 2*j] = torch.sin(torch.tensor(i / (n**(2*j / d_model))))
                vertical_pe[i, 2*j+1] = torch.cos(torch.tensor(i / (n**((2*j+1) / d_model))))

        for pos in range(max_len):
            for i in range(d_model // 2):
                horizontal_pe[pos, 2*i] = torch.sin(torch.tensor(pos / (n**(2*i / d_model))))
                horizontal_pe[pos, 2*i+1] = torch.cos(torch.tensor(pos / (n**((2*i+1) / d_model))))

        self.register_buffer('vertical_pe', vertical_pe)
        self.register_buffer('horizontal_pe', horizontal_pe)

    def forward(self, x):
        # x shape: [batch_size, num_embeddings, d_model]
        num_columns = x.size(1) // self.num_rows
        batch_size = x.size(0)

        # Construct full positional encoding for each element in the grid
        full_pe = torch.zeros(batch_size, self.num_rows * num_columns, self.d_model, device=x.device)

        for row in range(self.num_rows):
            for col in range(num_columns):
                full_pe[:, row * num_columns + col, :] = ((1 - self.horizontal_bias) * self.vertical_pe[row]) + (self.horizontal_bias * self.horizontal_pe[col])

        return x + full_pe

In [10]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def attention(self, query, key, value, mask=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        # if mask is not None:
        #     scores = scores.masked_fill(mask == 1, -1e9)
        if mask is not None:
            scores += (mask * -1e9)
        attention_weights = torch.nn.functional.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, value)
        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear projections
        query = self.linear_q(query)
        key = self.linear_k(key)
        value = self.linear_v(value)

        # Splitting heads
        query = query.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Attention
        output, attention_weights = self.attention(query, key, value, mask=mask)

        # Concatenation of heads
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)

        # Linear projection
        output = self.linear_out(output)

        # print(output.shape)
        return output, attention_weights

In [11]:
class GroupedMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, groups):
        super(GroupedMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        assert num_heads % groups == 0, "num_heads must be divisible by groups"

        self.num_heads = num_heads
        self.groups = groups
        self.group_heads = num_heads // groups
        self.d_k = d_model // self.group_heads

        self.linear_q = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(groups)])
        self.linear_k = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(groups)])
        self.linear_v = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(groups)])
        self.linear_out = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(groups)])

    def attention(self, query, key, value, mask=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        print(scores.shape)
        print(mask.shape)
        if mask is not None:
            scores += (mask * -1e9)
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, value)
        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        print("ATTENTION")
        batch_size = query.size(0)
        outputs = []
        attention_weights_list = []

        for i in range(self.groups):
            # Linear projections for each group
            q_proj = self.linear_q[i](query)
            k_proj = self.linear_k[i](key)
            v_proj = self.linear_v[i](value)

            # Splitting heads for each group
            q_proj = q_proj.view(batch_size, -1, self.group_heads, self.d_k).transpose(1, 2)
            k_proj = k_proj.view(batch_size, -1, self.group_heads, self.d_k).transpose(1, 2)
            v_proj = v_proj.view(batch_size, -1, self.group_heads, self.d_k).transpose(1, 2)

            # print(q_proj.shape)
            # start_index = i * (query.size(1) // self.groups)
            # end_index = (i + 1) * (query.size(1) // self.groups)
            # q_proj = q_proj[:, :, start_index:end_index, :]
            q_proj = q_proj[:, :, i::4, :]

            # Attention for each group
            group_output, group_attention_weights = self.attention(q_proj, k_proj, v_proj, mask=mask)


            # Concatenation of heads for each group
            group_output = group_output.transpose(1, 2).contiguous().view(batch_size, -1, self.group_heads * self.d_k)
            # group_output = group_output.transpose(1, 2).contiguous().view(batch_size, self.group_heads, -1, self.d_k)

            # Linear projection for each group
            group_output = self.linear_out[i](group_output)

            outputs.append(group_output)
            attention_weights_list.append(group_attention_weights)

        return outputs, attention_weights_list

In [12]:
class PositionwiseFeedforward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionwiseFeedforward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = torch.nn.functional.relu(self.linear1(x))
        x = self.linear2(x)
        return x

In [13]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedforward(d_model, d_ff)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Multi-head self-attention
        attn_output, _ = self.self_attn(x, x, x, mask=mask)
        # Add & Norm
        x = self.layer_norm1(x + self.dropout(attn_output))
        # Position-wise feedforward
        ffn_output = self.ffn(x)
        # Add & Norm
        x = self.layer_norm2(x + self.dropout(ffn_output))
        return x

In [14]:
class Encoder(nn.Module):
    def __init__(self, d_model, num_layers, num_heads, d_ff, dropout=0.1):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.layer_norm(x)

In [15]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.masked_attn_head = MultiHeadAttention(d_model, num_heads)
        self.enc_dec_attn_head = MultiHeadAttention(d_model, num_heads)
        self.positionwise_ffn = PositionwiseFeedforward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        # Masked multi-head attention
        attn_masked, _ = self.masked_attn_head(x, x, x, mask=tgt_mask)
        x = x + self.dropout(attn_masked)
        x = self.norm1(x)

        # Multi-head attention over encoder's output
        attn_enc_dec, _ = self.enc_dec_attn_head(x, enc_output, enc_output, mask=src_mask)
        x = x + self.dropout(attn_enc_dec)
        x = self.norm2(x)

        # Positionwise feedforward
        x = x + self.dropout(self.positionwise_ffn(x))
        x = self.norm3(x)

        return x

In [16]:
class GroupedDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, groups, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Grouped multi-head attention for the final layer
        self.grouped_attn = GroupedMultiHeadAttention(d_model, num_heads, groups)
        self.grouped_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(groups)])
        self.grouped_dropout = nn.Dropout(dropout)

        self.positionwise_ffn = nn.ModuleList([PositionwiseFeedforward(d_model, d_ff) for _ in range(groups)])
        self.norm2 = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(groups)])
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, memory, src_mask, tgt_mask):
        # Standard self-attention and add&norm
        x = self.norm1(x + self.dropout1(self.self_attn(x, x, x, tgt_mask)[0]))

        # Grouped attention; splitting x for each group
        grouped_outputs = self.grouped_attn(x, memory, memory, src_mask)[0]

        # Processing each group's output independently
        final_outputs = []
        for i, group_output in enumerate(grouped_outputs):
            group_output = self.grouped_norm[i](group_output + self.grouped_dropout(group_output))
            group_output = self.norm2[i](group_output + self.dropout2(self.positionwise_ffn[i](group_output)))
            final_outputs.append(group_output)

        return final_outputs  # Return the list of processed group outputs

In [17]:
class OldDecoder(nn.Module):
    def __init__(self, d_model, num_layers, num_heads, d_ff, dropout=0.1):
        super(OldDecoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        return self.norm(x)

class Decoder(nn.Module):
    def __init__(self, d_model, num_layers, num_heads, d_ff, groups, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers-1)])
        self.final_layer = GroupedDecoderLayer(d_model, num_heads, d_ff, groups, dropout)

    def forward(self, tgt, memory, src_mask, tgt_mask):
        for layer in self.layers:
            tgt = layer(tgt, memory, src_mask, tgt_mask)
        return self.final_layer(tgt, memory, src_mask, tgt_mask)

In [18]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, groups, dropout=0.1):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding2D(d_model)
        self.encoder = Encoder(d_model, num_layers, num_heads, d_ff, dropout)
        self.decoder = Decoder(d_model, num_layers, num_heads, d_ff, groups, dropout)
        self.output_layers = nn.ModuleList([nn.Linear(d_model, vocab_size) for _ in range(groups)])

    def create_padding_mask(self, inputs):
        # Create padding mask for inputs
        mask = torch.zeros(inputs.shape[0], inputs.shape[1]).to(device)
        mask = mask.masked_fill(inputs == 0, 1)
        mask = mask.view(inputs.shape[0], 1, 1, inputs.shape[1])

    def create_lookahead_mask(self, inputs):
        # Create lookahead mask for inputs
        mask = torch.triu(torch.ones((inputs.shape[1], inputs.shape[1])), diagonal=1)
        print(inputs.shape)
        print(mask.shape)
        print(mask)
        return mask

    def forward(self, src_seq, tgt_seq):
        src_mask = self.create_padding_mask(src_seq)
        print(src_mask.shape)

        # Create padding mask and lookahead mask for decoder inputs
        padding_mask_dec = self.create_padding_mask(tgt_seq)
        lookahead_mask_dec = self.create_lookahead_mask(tgt_seq).to(device)
        # Combine padding mask and lookahead mask for decoder
        dec_mask = torch.max(padding_mask_dec, lookahead_mask_dec)

        print("ENCODING")
        src_emb = self.embedding(src_seq)
        src_emb = self.positional_encoding(src_emb)
        enc_output = self.encoder(src_emb, src_mask)

        print("DECODING")
        tgt_emb = self.embedding(tgt_seq)
        tgt_emb = self.positional_encoding(tgt_emb)
        group_outputs = self.decoder(tgt_emb, enc_output, src_mask, dec_mask)

        output = [output_layer(group_output) for output_layer, group_output in zip(self.output_layers, group_outputs)]

        print("OUTPUT")
        print(output[0].shape)
        return output


class OldTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, groups, dropout=0.1):
        super(OldTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding2D(d_model)
        self.encoder = Encoder(d_model, num_layers, num_heads, d_ff, dropout)
        self.decoder = OldDecoder(d_model, num_layers, num_heads, d_ff, dropout)
        self.output_linear = nn.Linear(d_model, vocab_size)

    def create_padding_mask(self, inputs):
        # Create padding mask for inputs
        mask = torch.zeros(inputs.shape[0], inputs.shape[1]).to(device)
        mask = mask.masked_fill(inputs == 0, 1)
        mask = mask.view(inputs.shape[0], 1, 1, inputs.shape[1])
        return mask

    def create_lookahead_mask(self, inputs):
        # Create lookahead mask for inputs
        mask = torch.triu(torch.ones((inputs.shape[1], inputs.shape[1])), diagonal=1)
        return mask

    def forward(self, src_seq, tgt_seq):
        src_mask = self.create_padding_mask(src_seq).to(device)

        # Create padding mask and lookahead mask for decoder inputs
        padding_mask_dec = self.create_padding_mask(tgt_seq).to(device)
        lookahead_mask_dec = self.create_lookahead_mask(tgt_seq).to(device)
        # Combine padding mask and lookahead mask for decoder
        dec_mask = torch.max(padding_mask_dec, lookahead_mask_dec).to(device)

        # print("ENCODING")
        src_emb = self.embedding(src_seq)
        src_emb = self.positional_encoding(src_emb)
        enc_output = self.encoder(src_emb, src_mask)

        # print("DECODING")
        tgt_emb = self.embedding(tgt_seq)
        tgt_emb = self.positional_encoding(tgt_emb)
        dec_output = self.decoder(tgt_emb, enc_output, src_mask, dec_mask)

        output = self.output_linear(dec_output)

        # print("OUTPUT")
        # print(output.shape)
        return output

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

torch.mps.set_per_process_memory_fraction(2.0)

PATH = "1-PositionalEncodings.pt"

# INPUT FORMAT
# Step1_Voice1, Step1_Voice2, Step1_Voice3, Step1_Voice4, Step2..., EOS,
# TARGET_CLS_TOKEN_ID, Voice1, ... Voice 4, PAD, PAD, ...
#Alternative, target chord in final layer of decoder?

# Define some hyperparameters
d_model = 64
num_layers = 6
num_heads = 32
d_ff = 128
dropout = 0.1
batch_size = 32
voices = 4

# train_dataset = TensorDataset(src_data, tgt_data)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize the Transformer model
model = OldTransformer(VOCAB_SIZE, d_model, num_layers, num_heads, d_ff, voices, dropout).to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [21]:
import gc

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

num_epochs = 5

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for src_seq, tgt_seq in train_loader:
        src_seq = src_seq.to(device)
        tgt_seq = tgt_seq.to(device)

        optimizer.zero_grad()

        tgt_target = torch.cat((tgt_seq[:, 1:], torch.full_like(tgt_seq[:, :1], PAD_TOKEN_ID)), dim=1)

        output = model(src_seq, tgt_seq)

        # output_indexes = torch.argmax(output, dim=-1)
        # print(output_indexes[0])
        # print(tgt_seq[0])
        # print(tgt_target[0])

        loss = criterion(output.view(-1, VOCAB_SIZE), tgt_target.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        gc.collect()
        torch.mps.empty_cache() 

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, PATH)
    
    if (epoch+1) % 1 == 0:
        print('Epoch {} Loss: {:.4f}'.format(epoch+1, total_loss / len(train_loader)))
        
        # Print indexes of max values in the model output
        # print("Model output (indexes of max values):", output_indexes.view(-1)[:10])  # Example: print first 10 values
        output_indexes = torch.argmax(output, dim=-1)
        print(output_indexes[0][:32])
        # print(tgt_seq[0])
        print(tgt_target[0][:32])

        # print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")

# # Evaluation
# model.eval()
# with torch.no_grad():
#     test_output = model(test_src_data, test_tgt_data, test_src_mask)
#     test_loss = criterion(test_output.view(-1, tgt_vocab_size), test_tgt_data.view(-1))
#     print(f"Test Loss: {test_loss.item()}")

Epoch 1 Loss: 1.5504
tensor([39, 31, 17, 36, 36, 29, 17, 36, 36, 29, 17, 36, 36, 34, 17, 36, 34, 36,
        38, 39, 36, 36, 38, 39, 34, 36, 19, 39, 39, 36, 19, 39],
       device='mps:0')
tensor([36, 29, 17, 44, 36, 29, 17, 44, 38, 29, 17, 44, 38, 29, 17, 43, 39, 22,
        19, 43, 39, 22, 19, 43, 39, 22, 20, 43, 39, 22, 20, 41],
       device='mps:0')
Epoch 2 Loss: 1.4096
tensor([33, 22, 22, 33, 29, 26, 19, 38, 29, 26, 19, 38, 31, 26, 19, 38, 31, 26,
        19, 38, 31, 24, 17, 38, 31, 26, 39, 39, 31, 24, 17, 39],
       device='mps:0')
tensor([29, 26, 11, 38, 29, 26, 11, 38, 31, 26, 11, 38, 31, 26, 11, 39, 31, 24,
        12, 39, 31, 24, 12, 39, 31, 24, 12, 39, 31, 24, 12, 34],
       device='mps:0')
Epoch 3 Loss: 1.3111
tensor([32, 29, 27, 20, 32, 27, 24, 39, 32, 27, 24, 39, 32, 27, 24, 39, 32, 27,
        41, 13, 32, 27, 41, 41, 34, 25, 24, 13, 34, 27, 24, 13],
       device='mps:0')
tensor([32, 27, 24, 39, 32, 27, 24, 39, 32, 27, 24, 39, 32, 27, 24, 41, 32, 25,
        25, 41, 3

In [25]:
# hand-evaluation

# temp=.5 (.4-.7) best
def generate_next_tokens(model, src_seq, tgt_prefix, max_len=320, temperature=.6):
    model.eval()
    with torch.no_grad():
        src_seq = src_seq.unsqueeze(0).to(device)
        tgt_seq = tgt_prefix.unsqueeze(0).to(device)
        for i in range(max_len):
            output = model(src_seq, tgt_seq)
            output_last_token = output[:, -1, :]  # Take the last token of the output
            # next_token = torch.argmax(output_last_token, dim=-1).unsqueeze(1)
            probabilities = torch.nn.functional.softmax(output_last_token / temperature, dim=-1)
            next_token = torch.multinomial(probabilities, 1)  # Sample the next token
            tgt_seq = torch.cat([tgt_seq, next_token], dim=-1)
            if next_token == EOS_TOKEN_ID:  # Add a condition to stop generation
                break
    return tgt_seq.squeeze(0)

def to_score_form(seq, name="Unknown"):
    def id_to_note(id):
        if id == NO_NOTE_TOKEN_ID or id == PAD_TOKEN_ID:
            return 0
        return id + NOTE_OFFSET

    seq = seq.tolist()
    score = [[] for _ in range(4)]
    for i in range(len(seq)):
        score[i % 4].append(id_to_note(seq[i]))
    return [name, score]

gc.collect()
torch.mps.empty_cache() 

test_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
for src_seq, tgt_seq in test_loader:
    print("Generating Chorale")
    res = generate_next_tokens(model, src_seq[0], tgt_seq[0][:4])
    print("Generating Scores")
    score = to_score_form(res, "Generated Chorale")
    src_score = to_score_form(src_seq[0], "Source Chorale")
    tgt_score = to_score_form(tgt_seq[0], "Target Chorale")
    print("Exporting Scores")
    visualize_score(src_score, "2-LossFunction-Src")
    visualize_score(tgt_score, "2-LossFunction-Tgt")
    visualize_score(score, "2-LossFunction-Gen")

    gc.collect()
    torch.mps.empty_cache() 
    break

Generating Chorale
Generating Score
Exporting Score


In [None]:
del model
torch.mps.empty_cache()