# Dataset and Preprocessing

In [1]:
# get chorale dataset
# https://github.com/ageron/handson-ml2/blob/master/datasets/jsb_chorales/README.md

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Use the MPS backend when PyTorch reports it as available; otherwise fall back to the CPU.
device='mps' if torch.backends.mps.is_available() else 'cpu'

In [3]:
# preprocess the chorales
import os
import csv

# Running pitch bounds over every non-zero note parsed so far, stored as
# [lowest, highest]. They start inverted ([88, 0]) so the first note seen
# replaces both ends. parse_score_dir() widens them in place.
note_range = [88, 0]
voice_ranges = [[88, 0] for _ in range(4)]

def parse_score_dir(folder_path):
    """Read every ``.csv`` file in ``folder_path`` into per-voice note lists.

    The first row of each file is treated as a header and skipped. From each
    remaining row the first four columns are parsed as integers, one per
    voice. Blank rows are ignored.

    Side effect: the module-level ``note_range`` and ``voice_ranges`` are
    widened in place with every non-zero note encountered.

    Args:
        folder_path: directory containing the CSV files.

    Returns:
        A list of ``[filename, voices]`` entries in sorted filename order,
        where ``voices`` is a list of four lists of ints.
    """
    scores = []

    # os.listdir() gives entries in arbitrary order; sorting makes the result
    # (and anything sampled from it later) stable from run to run.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.csv'):
            continue
        file_path = os.path.join(folder_path, filename)

        # The csv module asks for files to be opened with newline=''.
        with open(file_path, 'r', newline='') as csv_file:
            csv_reader = csv.reader(csv_file)
            next(csv_reader, None)  # skip header (tolerates an empty file)

            voices = [[] for _ in range(4)]
            for row in csv_reader:
                if not row:
                    # csv.reader yields [] for blank lines, e.g. a trailing newline.
                    continue
                for i in range(4):
                    note = int(row[i])
                    voices[i].append(note)

                    # Zero is kept in the voice but excluded from the pitch bounds.
                    if note != 0:
                        note_range[0] = min(note_range[0], note)
                        note_range[1] = max(note_range[1], note)

                        voice_ranges[i][0] = min(voice_ranges[i][0], note)
                        voice_ranges[i][1] = max(voice_ranges[i][1], note)

        scores.append([filename, voices])

    return scores

# Parse the three dataset folders. Each call also widens the module-level
# note_range / voice_ranges, so the values printed below cover all three.
test = parse_score_dir("test")
train = parse_score_dir("train")
valid = parse_score_dir("valid")

print(note_range)
print(voice_ranges)

[36, 81]
[[60, 81], [52, 74], [46, 69], [36, 66]]


In [4]:
import math
from secrets import randbelow
import torch

def rand_num(min, max_excl):
    """Return a random integer in ``[min, max_excl)`` using ``secrets.randbelow``.

    ``max_excl - min`` is truncated to an int and must be at least 1;
    otherwise ``randbelow`` raises ``ValueError``.
    """
    return randbelow(int(max_excl - min)) + min

def split_scores(scores, repeats=20):
    """Sample non-overlapping (source, target) excerpt pairs from each score.

    Lengths are measured in beats of four time steps. For every score,
    ``repeats`` attempts are made: a source excerpt of 16-20 beats is drawn at
    a random beat offset, then a target excerpt of 17-20 beats is drawn from
    whichever side of the source (before or after) has more beats. An attempt
    is dropped when the score is shorter than the drawn source length or when
    the chosen side has 17 beats or fewer.

    Args:
        scores: list of ``[filename, voices]`` as returned by ``parse_score_dir``.
        repeats: number of sampling attempts per score.

    Returns:
        A list of ``[filename, src_voices, tgt_voices]``; each ``*_voices`` is
        a list with one slice per voice.
    """
    res = []
    for filename, voices in scores:
        # Count whole beats only. Integer arithmetic keeps every bound below
        # an int; with a fractional beat count the target-length draw could
        # reach randbelow(0) and raise.
        num_beats = len(voices[0]) // 4

        for _ in range(repeats):
            src_seq_length = rand_num(16, 21)
            if num_beats < src_seq_length:
                # Score too short for this source length: skip the attempt
                # instead of calling randbelow() with a non-positive bound.
                continue

            # pick random index for training seq
            i = randbelow(num_beats - src_seq_length + 1)
            src_voices = [v[i*4:(i+src_seq_length)*4] for v in voices]

            beats_before = i
            beats_after = num_beats - (i + src_seq_length)

            # pick index for tgt on the roomier side of the source excerpt
            if beats_before > beats_after:
                # tgt at beginning
                if beats_before <= 17:
                    continue
                tgt_seq_length = min(20, rand_num(17, beats_before))
                tgt_i = randbelow(beats_before - tgt_seq_length + 1)
            else:
                # tgt at end
                if beats_after <= 17:
                    continue
                tgt_seq_length = min(20, rand_num(17, beats_after))
                # The earliest start is one beat past the end of the source.
                tgt_i = rand_num(i + src_seq_length + 1, num_beats - tgt_seq_length + 1)

            tgt_voices = [v[tgt_i*4:(tgt_i+tgt_seq_length)*4] for v in voices]
            res.append([filename, src_voices, tgt_voices])

    return res

# Sample excerpt pairs for each split. The second argument is the number of
# sampling attempts per score (valid uses the function's default of 20).
split_train = split_scores(train, 40)
split_test = split_scores(test, 15)
split_valid = split_scores(valid)

# Disabled check: longest source / target excerpt (in time steps) in the training split.
# max_src = 0
# max_tgt = 0
# for s in split_train:
#     max_src = max(len(s[1][0]), max_src)
#     max_tgt = max(len(s[2][0]), max_tgt)

# print(max_src)
# print(max_tgt)

# Report how many pairs survived sampling in each split.
print(f"Num in train: {len(split_train)}")
print(f"Num in valid: {len(split_valid)}")
print(f"Num in test: {len(split_test)}")

Num in train: 7652
Num in valid: 1216
Num in test: 954


In [5]:
from torch.utils.data import DataLoader, TensorDataset
import torch

# Token id layout: ids 0-2 are reserved control tokens; pitch ids start at 3.
PAD_TOKEN_ID = 0
EOS_TOKEN_ID = 1
NO_NOTE_TOKEN_ID = 2
# Subtracting NOTE_OFFSET from a pitch maps the lowest observed pitch (note_range[0]) to id 3.
NOTE_OFFSET = note_range[0] - 3
# Three control tokens plus one id for every pitch in the inclusive observed range.
VOCAB_SIZE = 3 + note_range[1] - note_range[0] + 1

def to_dataset(split_scores, max_src_seq_len=324, max_tgt_seq_len=324):
    """Encode sampled excerpt pairs as fixed-length token-id tensors.

    Every excerpt is flattened time step by time step, emitting the four
    voices of a step next to each other (voice 0 first). Target sequences get
    four EOS tokens appended. Unused trailing positions stay 0.

    Args:
        split_scores: list of ``[filename, src_voices, tgt_voices]``.
        max_src_seq_len: length of every encoded source sequence.
        max_tgt_seq_len: length of every encoded target sequence.

    Returns:
        A ``TensorDataset`` of ``(src, tgt)`` int64 tensors moved to ``device``.

    Raises:
        IndexError: if an encoded excerpt does not fit in its maximum length.
    """
    def note_to_id(note):
        # A zero note gets its own token; every other pitch is shifted by NOTE_OFFSET.
        if note == 0:
            return NO_NOTE_TOKEN_ID
        return note - NOTE_OFFSET

    def encode(voices, max_len, append_eos):
        # Interleave the four voices: position step*4 + voice.
        num_steps = len(voices[0])
        ids = [note_to_id(voices[voice][step])
               for step in range(num_steps)
               for voice in range(4)]
        if append_eos:
            # One EOS per voice slot, directly after the last real time step.
            ids.extend([EOS_TOKEN_ID] * 4)
        if len(ids) > max_len:
            raise IndexError(
                f"encoded sequence of {len(ids)} tokens does not fit in {max_len}")

        seq = torch.zeros(max_len, dtype=torch.int64)
        seq[:len(ids)] = torch.tensor(ids, dtype=torch.int64)
        return seq

    src_data = []
    tgt_data = []
    for _filename, src_voices, tgt_voices in split_scores:
        src_data.append(encode(src_voices, max_src_seq_len, append_eos=False))
        tgt_data.append(encode(tgt_voices, max_tgt_seq_len, append_eos=True))

    dataset = TensorDataset(torch.stack(src_data).to(device), torch.stack(tgt_data).to(device))
    return dataset

# Encode the sampled pairs using to_dataset's default maximum sequence lengths.
train_dataset = to_dataset(split_train)
test_dataset = to_dataset(split_test)
# valid_dataset = to_dataset(split_valid)

# Score Visualization

In [6]:
import math
from pathlib import Path
import musicscore
from musicscore import *
import uuid

def gen_id():
    """Return a fresh id: the letter 'a' followed by the 32 hex digits of a random UUID4."""
    return "a" + uuid.uuid4().hex

def visualize_score(score, name):
    """Build a four-part score (Soprano, Alto, Tenor, Bass) and export it as XML.

    Args:
        score: ``[title, content]`` where ``content`` holds one note list per part.
        name: output path; its suffix is replaced with ``.xml``.

    Equal consecutive notes are merged into one longer chord, but never across
    a group of four entries. Every note is given a sharp accidental, switched
    to flat when the next different note is lower by fewer than 3.
    """
    title, content = score
    score_doc = Score(title=title)

    parts = [score_doc.add_child(Part(gen_id(), name=voice_name))
             for voice_name in ('Soprano', 'Alto', 'Tenor', 'Bass')]
    clefs = [TrebleClef(), TrebleClef(), BassClef(), BassClef()]

    for part_index, (part, clef) in enumerate(zip(parts, clefs)):
        pitches = content[part_index]

        # Every part opens with measure 1 carrying its clef.
        opening_measure = part.add_child(Measure(number=1))
        opening_measure.add_child(Staff(clef=clef))

        total = len(pitches)
        last_index = total - 1
        cursor = 0
        while cursor < total:
            pitch = pitches[cursor]

            # Absorb repeats of the same note up to the end of the current group of four.
            held_steps = 1
            while (cursor != last_index
                   and (cursor + 1) % 4 != 0
                   and pitches[cursor + 1] == pitch):
                held_steps += 1
                cursor += 1
            duration = held_steps * .25

            accidental = musicscore.accidental.Accidental("sharp")

            # First later entry that differs from the current note, if any.
            following = next(
                (pitches[j] for j in range(cursor + 1, total) if pitches[j] != pitch),
                None,
            )
            if following is not None and following < pitch and pitch - following < 3:
                accidental.mode = "flat"

            midi = musicscore.midi.Midi(pitch, accidental)
            part.add_chord(Chord(midi, duration))
            cursor += 1

    score_doc.export_xml(Path(name).with_suffix('.xml'))

# Neural Network

In [7]:
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to its input.

    Row ``pos`` of the table holds ``sin(pos / n**(2*i/d_model))`` in column
    ``2*i`` and the matching cosine in column ``2*i + 1``. When ``d_model`` is
    odd the last column stays zero.

    Args:
        d_model: width of the encoding (number of columns).
        max_len: number of positions in the table.
        n: base of the frequency progression.
    """
    def __init__(self, d_model, max_len=5000, n=10000):
        super(PositionalEncoding, self).__init__()

        num_pairs = d_model // 2
        # Build the whole table with tensor ops rather than one scalar tensor
        # per entry. Angles are formed in float64 and then rounded to the
        # table's dtype before taking sin/cos.
        divisors = torch.tensor(
            [n ** (2 * i / d_model) for i in range(num_pairs)], dtype=torch.float64)
        positions = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)

        horizontal_pe = torch.zeros(max_len, d_model)
        angles = (positions / divisors).to(horizontal_pe.dtype)
        horizontal_pe[:, 0:2 * num_pairs:2] = torch.sin(angles)
        horizontal_pe[:, 1:2 * num_pairs:2] = torch.cos(angles)

        # Registered as a buffer: moves with the module and is saved in its state_dict.
        self.register_buffer('horizontal_pe', horizontal_pe)

    def forward(self, x):
        """Return ``x`` plus the first ``x.size(1)`` rows of the table.

        Assumes ``x`` is (batch, seq, d_model) with seq <= max_len; TODO confirm with callers.
        """
        return x + self.horizontal_pe[:x.size(1), :]

In [8]:
class PositionalEncoding2D(nn.Module):
    """Adds a fixed sinusoidal table indexed by (column, row) to its input.

    The table has ``max_len * 4`` rows. The entry for column ``col`` and row
    ``row`` (0-3) lives at flat index ``col * 4 + row``, which matches a
    sequence that stores four values per column next to each other.

    Args:
        d_model: width of the encoding.
        max_len: number of columns covered (the table has ``max_len * 4`` rows).
        n: base of the frequency progression.
    """
    def __init__(self, d_model, max_len=320, n=10000):
        super(PositionalEncoding2D, self).__init__()
        self.max_len = max_len
        self.d_model = d_model
        self.n = n

        num_pairs = d_model // 2
        divisors = torch.tensor(
            [n ** (2 * i / d_model) for i in range(num_pairs)], dtype=torch.float64)

        # (max_len, 4) grids flattened row-major, so flat index == col * 4 + row.
        # The previous indexing (row * 4 + col) overwrote the first rows several
        # times and left most of the table at zero.
        cols = torch.arange(max_len, dtype=torch.float64).unsqueeze(1).expand(max_len, 4)
        rows = torch.arange(4, dtype=torch.float64).unsqueeze(0).expand(max_len, 4)
        cols = cols.reshape(-1, 1)
        rows = rows.reshape(-1, 1)

        # NOTE(review): the column and row angles are summed, so two positions
        # with the same col + row receive the same encoding. Kept as in the
        # original formula; confirm whether that is intended.
        pe = torch.zeros(max_len * 4, d_model)
        angles = (cols / divisors + rows / divisors).to(pe.dtype)
        pe[:, 0:2 * num_pairs:2] = torch.sin(angles)
        pe[:, 1:2 * num_pairs:2] = torch.cos(angles)

        # Registered as a buffer, so it is stored in the state_dict; loading a
        # checkpoint restores the table that checkpoint was saved with.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` (batch, seq, emb) plus the first ``seq`` rows of the table."""
        _, seq_len, _ = x.size()

        return x + self.pe[:seq_len, :]

In [9]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``num_heads`` heads.

    ``d_model`` must be divisible by ``num_heads``; each head works on
    ``d_model // num_heads`` features.
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def attention(self, query, key, value, mask=None):
        """Return ``(weighted values, attention weights)``.

        ``mask`` is multiplied by -1e9 and added to the scores before the
        softmax, so positions where it is 1 receive (near) zero weight.
        """
        scale = math.sqrt(self.d_k)
        logits = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            logits.add_(mask * -1e9)
        weights = torch.softmax(logits, dim=-1)
        return torch.matmul(weights, value), weights

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, merge the heads and project again.

        Returns ``(output, attention_weights)``.
        """
        batch_size = query.size(0)

        def split_heads(projected):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        per_head_q = split_heads(self.linear_q(query))
        per_head_k = split_heads(self.linear_k(key))
        per_head_v = split_heads(self.linear_v(value))

        context, attention_weights = self.attention(per_head_q, per_head_k, per_head_v, mask=mask)

        # (batch, heads, seq, d_k) -> (batch, seq, heads * d_k)
        merged = context.transpose(1, 2).contiguous().view(
            batch_size, -1, self.num_heads * self.d_k)

        return self.linear_out(merged), attention_weights

In [10]:
class PositionwiseFeedforward(nn.Module):
    """Two linear layers with a ReLU in between: d_model -> d_ff -> d_model."""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU, and project back to ``d_model``."""
        hidden = self.linear1(x).relu()
        return self.linear2(hidden)

In [11]:
class EncoderLayer(nn.Module):
    """Self-attention followed by a feed-forward block.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedforward(d_model, d_ff)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run both sub-layers on ``x``; ``mask`` is passed to the self-attention."""
        # Self-attention sub-layer (attention weights are discarded).
        attended, _ = self.self_attn(x, x, x, mask=mask)
        after_attention = self.layer_norm1(x + self.dropout(attended))

        # Feed-forward sub-layer.
        transformed = self.ffn(after_attention)
        return self.layer_norm2(after_attention + self.dropout(transformed))

In [12]:
class Encoder(nn.Module):
    """A stack of ``num_layers`` EncoderLayer blocks followed by a final LayerNorm."""
    def __init__(self, d_model, num_layers, num_heads, d_ff, dropout=0.1):
        super().__init__()
        blocks = [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        """Pass ``x`` through every layer in order with the same ``mask``, then normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.layer_norm(hidden)

In [13]:
class DecoderLayer(nn.Module):
    """Masked self-attention, attention over the encoder output, then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.masked_attn_head = MultiHeadAttention(d_model, num_heads)
        self.enc_dec_attn_head = MultiHeadAttention(d_model, num_heads)
        self.positionwise_ffn = PositionwiseFeedforward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Apply the three sub-layers.

        ``tgt_mask`` is used by the self-attention, ``src_mask`` by the
        attention over ``enc_output``.
        """
        # Self-attention over the decoder input.
        self_attended, _ = self.masked_attn_head(x, x, x, mask=tgt_mask)
        hidden = self.norm1(x + self.dropout(self_attended))

        # Queries come from the decoder; keys and values from the encoder output.
        cross_attended, _ = self.enc_dec_attn_head(hidden, enc_output, enc_output, mask=src_mask)
        hidden = self.norm2(hidden + self.dropout(cross_attended))

        # Feed-forward sub-layer.
        return self.norm3(hidden + self.dropout(self.positionwise_ffn(hidden)))

In [14]:
class Decoder(nn.Module):
    """A stack of ``num_layers`` DecoderLayer blocks followed by a final LayerNorm."""
    def __init__(self, d_model, num_layers, num_heads, d_ff, dropout=0.1):
        super().__init__()
        blocks = [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Pass ``x`` through every layer with the same encoder output and masks, then normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, enc_output, src_mask, tgt_mask)
        return self.norm(hidden)

In [15]:
class Transformer(nn.Module):
    """Encoder-decoder model over token-id sequences.

    The source and target sequences share one embedding table and one
    positional encoding. The decoder output is projected to ``vocab_size``
    logits per position.

    Args:
        vocab_size: number of token ids.
        d_model: embedding / hidden width.
        num_layers: number of layers in both the encoder and the decoder.
        num_heads: attention heads per layer.
        d_ff: hidden width of the feed-forward blocks.
        groups: accepted for call compatibility; not used by this class.
        dropout: dropout probability inside the layers.
    """
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, groups, dropout=0.1):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding2D(d_model)
        self.encoder = Encoder(d_model, num_layers, num_heads, d_ff, dropout)
        self.decoder = Decoder(d_model, num_layers, num_heads, d_ff, dropout)
        self.output_linear = nn.Linear(d_model, vocab_size)

    def create_padding_mask(self, inputs):
        """Return a (batch, 1, 1, seq) float mask that is 1 where ``inputs`` is 0.

        The mask is created on the same device as ``inputs`` instead of
        relying on a module-level ``device`` variable.
        """
        mask = (inputs == 0).to(torch.get_default_dtype())
        return mask.view(inputs.shape[0], 1, 1, inputs.shape[1])

    def create_lookahead_mask(self, inputs):
        """Return a (seq, seq) float mask that is 1 strictly above the diagonal."""
        seq_len = inputs.shape[1]
        return torch.triu(torch.ones((seq_len, seq_len), device=inputs.device), diagonal=1)

    def forward(self, src_seq, tgt_seq):
        """Return logits of shape (batch, tgt_len, vocab_size).

        Args:
            src_seq: (batch, src_len) integer token ids.
            tgt_seq: (batch, tgt_len) integer token ids fed to the decoder.
        """
        src_mask = self.create_padding_mask(src_seq)

        # A decoder position is masked when it is padding OR lies in the
        # future; broadcasting gives a (batch, 1, tgt_len, tgt_len) mask.
        dec_mask = torch.max(self.create_padding_mask(tgt_seq),
                             self.create_lookahead_mask(tgt_seq))

        src_emb = self.positional_encoding(self.embedding(src_seq))
        enc_output = self.encoder(src_emb, src_mask)

        tgt_emb = self.positional_encoding(self.embedding(tgt_seq))
        dec_output = self.decoder(tgt_emb, enc_output, src_mask, dec_mask)

        return self.output_linear(dec_output)

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

class CustomLoss(nn.Module):
    """Weighted negative log-likelihood plus a voice-range penalty.

    Args:
        voice_ranges: four ``[lowest, highest]`` pitch pairs, one per voice.
            They are converted to token ids by subtracting ``NOTE_OFFSET``.
        downbeat_weight: factor applied to the loss at downbeat positions.
        note_change_weight: factor applied where a voice changes note.
        voice_range_weight: factor applied to the voice-range penalty.
    """
    def __init__(self, voice_ranges, downbeat_weight=2, note_change_weight=2, voice_range_weight=4):
        super(CustomLoss, self).__init__()

        self.voice_range_weight = voice_range_weight
        self.downbeat_weight = downbeat_weight
        self.note_change_weight = note_change_weight

        # Shift the bounds from pitches to token ids. Fresh lists are built so
        # the caller's voice_ranges is not modified (and no `copy` import is needed).
        self.voice_ranges = [[bound - NOTE_OFFSET for bound in bounds]
                             for bounds in voice_ranges]

        # Not used by forward(); kept so the attribute still exists.
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def cross_entropy(self, predictions, tgt_seq, targets):
        """Mean per-position negative log-likelihood with two re-weightings.

        Args:
            predictions: (batch, seq, vocab) logits.
            tgt_seq: (batch, seq) token ids that were fed to the decoder.
            targets: (batch, seq) token ids the model should predict.
        """
        # Softmax computed by hand; the row maximum is subtracted for numerical stability.
        max_logits = predictions.max(dim=2, keepdim=True)[0]
        exp_logits = torch.exp(predictions - max_logits)
        softmax_probs = exp_logits / exp_logits.sum(dim=2, keepdim=True)

        gathered_probs = softmax_probs.gather(dim=2, index=targets.unsqueeze(2)).squeeze(2)
        nll = -torch.log(gathered_probs + 1e-9)

        # downbeat bias: mark every 4th group of 4 positions, then shift the
        # marks one position to the left.
        # NOTE(review): the shift presumably accounts for position t scoring
        # the token at t + 1 - confirm against how `targets` is built.
        seq_len = nll.size(1)
        down_beat_mask = torch.zeros(seq_len // 4, 4, dtype=torch.bool, device=nll.device)
        down_beat_mask[::4, :] = True
        down_beat_mask = down_beat_mask.view(-1)
        down_beat_mask = torch.roll(down_beat_mask, shifts=-1)
        nll[:, down_beat_mask] *= self.downbeat_weight

        # note-change bias: a position is marked when its token differs from
        # the token 4 positions earlier; the marks get the same left shift.
        note_change_mask = torch.zeros_like(nll, dtype=torch.bool)
        note_change_mask[:, 4:] = tgt_seq[:, 4:] != tgt_seq[:, :-4]
        note_change_mask_clone = note_change_mask.clone()
        note_change_mask[:, :-1] = note_change_mask_clone[:, 1:]
        note_change_mask[:, -1] = False
        nll[note_change_mask] *= self.note_change_weight

        total_loss = nll.mean()

        return total_loss

    def voice_range_loss(self, predictions, targets):
        """Mean distance by which predicted tokens fall outside their voice's range.

        Predicted ids of 2 or below (the control tokens) are never penalised.
        The soprano is only checked against its lower bound and the bass only
        against its upper bound. ``targets`` is not used.

        NOTE(review): the penalty is computed from ``argmax`` token ids, so no
        gradient flows through it to the model - confirm this is intended.
        """
        pred_tokens = predictions.argmax(dim=-1)
        not_note_mask = (pred_tokens <= 2)

        # (voice index, slot within each group of 4 positions, check low, check high).
        # Output position p holds the prediction for token p + 1, hence the
        # slot order 3, 0, 1, 2 for soprano, alto, tenor, bass.
        checks = (
            (0, 3, True, False),   # soprano: too low only
            (1, 0, True, True),    # alto
            (2, 1, True, True),    # tenor
            (3, 2, False, True),   # bass: too high only
        )

        total_penalty = torch.zeros_like(pred_tokens)
        for voice, slot, check_low, check_high in checks:
            low, high = self.voice_ranges[voice]

            voice_mask = torch.zeros_like(pred_tokens, dtype=torch.bool)
            voice_mask[:, slot::4] = True
            ignored = not_note_mask | ~voice_mask

            if check_low:
                below = (low - pred_tokens) * (pred_tokens < low)
                total_penalty = total_penalty + below.masked_fill(ignored, 0)
            if check_high:
                above = (pred_tokens - high) * (pred_tokens > high)
                total_penalty = total_penalty + above.masked_fill(ignored, 0)

        return total_penalty.to(torch.get_default_dtype()).mean()

    def forward(self, predictions, tgt_seq, targets):
        """Return ``voice_range_weight * voice_range_loss + cross_entropy``."""
        l1 = self.voice_range_weight * self.voice_range_loss(predictions, targets)
        l2 = self.cross_entropy(predictions, tgt_seq, targets)
        total_loss = l1 + l2
        return total_loss


In [17]:
# The memory-fraction setting belongs to the MPS backend; skip it when the
# notebook fell back to the CPU so this cell still runs there.
if device == 'mps':
    torch.mps.set_per_process_memory_fraction(2.0)

# Checkpoint file used for both loading and saving.
PATH = "6-LossFunctionLarge.pt"

# INPUT FORMAT
# Step1_Voice1, Step1_Voice2, Step1_Voice3, Step1_Voice4, Step2..., EOS,
# TARGET_CLS_TOKEN_ID, Voice1, ... Voice 4, PAD, PAD, ...
#Alternative, target chord in final layer of decoder?
# NOTE(review): no TARGET_CLS_TOKEN_ID is defined in this notebook; the note
# above looks like an earlier design sketch - confirm or remove.

# Define some hyperparameters
d_model = 128
num_layers = 5
num_heads = 32
d_ff = 256
dropout = 0.1
batch_size = 64
voices = 4

# train_dataset = TensorDataset(src_data, tgt_data)

# Initialize the Transformer model
model = Transformer(VOCAB_SIZE, d_model, num_layers, num_heads, d_ff, voices, dropout).to(device)

# Define loss function and optimizer
criterion = CustomLoss(voice_ranges)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [18]:
# Evaluation
# Loader and loss that evaluate_model() below reads as module-level names.
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
criterion = CustomLoss(voice_ranges)

def evaluate_model(model):
    """Return the mean per-batch loss of ``model`` over the module-level ``test_loader``.

    Puts the model in eval mode (and leaves it there) and runs without
    gradients. The expected output is the decoder input shifted left by one
    position, with PAD appended at the end.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for source_batch, decoder_input in test_loader:
            pad_column = torch.full_like(decoder_input[:, :1], PAD_TOKEN_ID)
            shifted_target = torch.cat((decoder_input[:, 1:], pad_column), dim=1)
            logits = model(source_batch, decoder_input)
            batch_losses.append(criterion(logits, decoder_input, shifted_target).item())

    return sum(batch_losses) / len(test_loader)

In [19]:
# Resume from the saved checkpoint when there is one; on a fresh run the file
# does not exist yet (the training loop below creates it).
if os.path.exists(PATH):
    # map_location lets a checkpoint saved on one backend load on whichever
    # device this session selected.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    checkpoint = torch.load(PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
    print(f"No checkpoint found at {PATH}; starting from freshly initialised weights.")

In [20]:
import gc

# Drop unreferenced Python objects, then release cached device memory. The
# cache call is MPS-specific, so it is skipped on the CPU fallback.
gc.collect()
if device == 'mps':
    torch.mps.empty_cache()

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
criterion = CustomLoss(voice_ranges)

num_epochs = 1

# One-element tensor holding 99, despite its name. print_seq() prepends it to
# the sequence before splitting it into four rows.
zero_tensor = torch.tensor([99]).to(device)
def print_seq(seq, print_len=16):
    """Print a flat token sequence as four rows.

    The module-level ``zero_tensor`` is prepended first. Row ``r`` then shows
    every 4th element starting at index ``r``, limited to ``print_len``
    values, with a ``|`` before every group of four.
    """
    padded = torch.cat((zero_tensor, seq))
    for offset in range(4):
        shown = padded[offset::4].tolist()[:print_len]
        pieces = []
        for idx, token in enumerate(shown):
            prefix = "|" if idx % 4 == 0 else ""
            pieces.append(prefix + str(token) + " ")
        print("".join(pieces))

# Training loop
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}" + "=" * 45)
    total_loss = 0
    num_batches = len(train_loader)
    for batch_idx, (src_seq, tgt_seq) in enumerate(train_loader):
        # evaluate_model() leaves the model in eval mode, so training mode is
        # re-enabled at the start of every batch.
        model.train()
        optimizer.zero_grad()

        # Expected output: the decoder input shifted left by one, PAD appended.
        tgt_target = torch.cat((tgt_seq[:, 1:], torch.full_like(tgt_seq[:, :1], PAD_TOKEN_ID)), dim=1)

        output = model(src_seq, tgt_seq)

        # Show one sample before the first optimisation step of the epoch.
        # (The batch index is used rather than comparing the running loss to 0.)
        if batch_idx == 0:
            output_indexes = torch.argmax(output, dim=-1)
            print("Generated" + "-" * 43)
            print_seq(output_indexes[0])
            print("Target" + "-" * 46)
            print_seq(tgt_target[0])

        loss = criterion(output, tgt_seq, tgt_target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        print(f"\rBatch [{batch_idx + 1}/{num_batches}] Loss: {loss.item():.4f}", end='', flush=True)

        # Every 15 batches: report the test loss and show a sample from this batch.
        if ((batch_idx+1) % 15 == 0):
            print(f" | Test Loss: {evaluate_model(model)}")
            output_indexes = torch.argmax(output, dim=-1)
            print("\nGenerated" + "-" * 43)
            print_seq(output_indexes[0])
            print("Target" + "-" * 46)
            print_seq(tgt_target[0])

        # The checkpoint is rewritten after every batch.
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        torch.save(checkpoint, PATH)

        gc.collect()
        # The cache call is MPS-specific; skip it on the CPU fallback.
        if device == 'mps':
            torch.mps.empty_cache()

    if (epoch+1) % 1 == 0:
        print('Epoch {} Loss: {:.4f}'.format(epoch+1, total_loss / len(train_loader)))
        # print(f"Test Loss: {evaluate_model(model)}")
        # print(f"Test Loss: {evaluate_model(model)}")


Generated-------------------------------------------
|99 36 37 36 |37 37 37 37 |39 39 39 39 |41 41 41 41 
|32 32 32 32 |32 32 32 32 |32 32 32 32 |32 32 32 32 
|27 27 27 27 |29 29 29 25 |24 24 24 24 |25 25 25 25 
|20 20 20 18 |17 17 17 17 |20 20 20 20 |13 13 13 13 
Target----------------------------------------------
|99 36 36 36 |37 37 37 37 |39 39 39 39 |41 41 41 41 
|32 32 32 32 |32 32 32 32 |32 32 32 32 |32 32 32 32 
|27 27 27 27 |25 25 25 25 |25 25 24 24 |25 25 25 25 
|20 20 18 18 |17 17 13 13 |20 20 20 20 |13 13 13 13 
Batch [15/120] Loss: 0.8385 | Test Loss: 2.0354145765304565

Generated-------------------------------------------
|99 41 41 41 |41 41 41 41 |43 39 39 39 |39 36 36 36 
|36 36 36 36 |36 34 34 34 |34 34 34 34 |32 32 32 32 
|32 32 32 29 |29 29 29 29 |31 31 31 31 |27 20 27 24 
|20 20 20 20 |22 22 22 22 |15 15 15 15 |20 20 20 20 
Target----------------------------------------------
|99 41 41 41 |41 41 41 41 |39 39 39 39 |36 36 36 36 
|36 36 36 36 |34 34 34 34 |34 34 34 3

In [21]:
# Mean per-batch loss over test_loader for the model in its current state.
print(f"Test Loss: {evaluate_model(model)}")

Test Loss: 2.039167372385661


In [24]:
# hand-evaluation

# Note: a temperature around 0.5 (0.4-0.7) worked best in hand evaluation; the default below is 0.3.
def generate_next_tokens(model, src_seq, tgt_prefix, max_len=320, temperature=.3):
    """Sample a continuation of ``tgt_prefix`` one token at a time.

    Each step reruns the model on the source and everything generated so far,
    divides the logits of the last position by ``temperature``, and samples
    the next token from the resulting distribution. Generation stops after
    ``max_len`` tokens or right after an EOS token has been appended.

    Args:
        model: called as ``model(src, tgt)``; put into eval mode here.
        src_seq: 1-D source token ids.
        tgt_prefix: 1-D token ids the output starts with.
        max_len: maximum number of sampled tokens.
        temperature: values below 1 sharpen the sampling distribution.

    Returns:
        1-D tensor: the prefix followed by the sampled tokens.
    """
    model.eval()
    with torch.no_grad():
        source = src_seq.unsqueeze(0).to(device)
        generated = tgt_prefix.unsqueeze(0).to(device)
        for _ in range(max_len):
            last_logits = model(source, generated)[:, -1, :]
            probabilities = torch.nn.functional.softmax(last_logits / temperature, dim=-1)
            sampled = torch.multinomial(probabilities, 1)
            generated = torch.cat([generated, sampled], dim=-1)
            if sampled == EOS_TOKEN_ID:
                break

    return generated.squeeze(0)

def to_score_form(seq, name="Unknown"):
    """Convert a flat token-id tensor into ``[name, parts]``.

    Token ``k`` goes to part ``k % 4``. PAD, EOS and no-note tokens become 0;
    every other id is shifted back by ``NOTE_OFFSET``.
    """
    special_ids = (NO_NOTE_TOKEN_ID, PAD_TOKEN_ID, EOS_TOKEN_ID)
    parts = [[], [], [], []]
    for position, token in enumerate(seq.tolist()):
        pitch = 0 if token in special_ids else token + NOTE_OFFSET
        parts[position % 4].append(pitch)
    return [name, parts]

gc.collect()
# The cache call is MPS-specific; skip it on the CPU fallback.
if device == 'mps':
    torch.mps.empty_cache()

# Take one random test pair, generate a continuation from the first four
# target tokens, and export source / target / generated scores.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
for src_seq, tgt_seq in test_loader:
    print("Generating Chorale")
    res = generate_next_tokens(model, src_seq[0], tgt_seq[0][:4])
    print("Generating Scores")
    score = to_score_form(res, "Generated Chorale")
    src_score = to_score_form(src_seq[0], "Source Chorale")
    tgt_score = to_score_form(tgt_seq[0], "Target Chorale")
    print("Exporting Scores")
    visualize_score(src_score, "6-LossFunctionLarge-Src")
    visualize_score(tgt_score, "6-LossFunctionLarge-Tgt")
    visualize_score(score, "6-LossFunctionLarge-Gen")

    gc.collect()
    if device == 'mps':
        torch.mps.empty_cache()
    # Only the first batch is used.
    break

Generating Chorale
Generating Scores
Exporting Scores


In [28]:
# Drop the model reference, then release cached device memory. The cache call
# is MPS-specific, so it is skipped on the CPU fallback.
del model
if device == 'mps':
    torch.mps.empty_cache()