In [523]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import re

In [524]:
data = '/Applications/ML projects/Failures/Song Lyrics/Dataset - 2/archive/csv'

In [525]:
def readLyrics(data):
  lyrics = []
  for CSV_FILE in tqdm(os.listdir(data)):
    CSV_PATH = os.path.join(data, CSV_FILE)
    df = pd.read_csv(CSV_PATH)
    df = df.dropna()
    dfLyrics = df['Lyric'].values
    for lyric in dfLyrics:
      lyrics.append(lyric)

  return lyrics

In [526]:
lyrics = readLyrics(data)

100%|██████████| 21/21 [00:00<00:00, 84.95it/s]


In [527]:
print(lyrics[0])

one one one one one   talkin' in my sleep at night makin' myself crazy out of my mind out of my mind wrote it down and read it out hopin' it would save me too many times too many times  refrain my love he makes me feel like nobody else nobody else but my love he doesn't love me so i tell myself i tell myself  pre one don't pick up the phone you know he's only callin' 'cause he's drunk and alone two don't let him in you'll have to kick him out again three don't be his friend you know you're gonna wake up in his bed in the morning and if you're under him you ain't gettin' over him   i got new rules i count 'em i got new rules i count 'em i gotta tell them to myself i got new rules i count 'em i gotta tell them to myself   i keep pushin' forwards but he keeps pullin' me backwards nowhere to turn no way nowhere to turn no now i'm standin' back from it i finally see the pattern i never learn i never learn  refrain but my love he doesn't love me so i tell myself i tell myself i do i do i do 

In [528]:
def preprocessLyrics(lyrics):
    for i in tqdm(range(len(lyrics))):
        lyric = lyrics[i]
        lyric = re.sub(r'[^a-zA-Z]\s', '', lyric)
        lyric = " ".join([word for word in lyric.split() if len(word) > 1])
        lyrics[i] = lyric

In [529]:
preprocessLyrics(lyrics)

  0%|          | 0/3422 [00:00<?, ?it/s]

100%|██████████| 3422/3422 [00:00<00:00, 9089.51it/s]


In [530]:
print(lyrics[0])

one one one one one talkinin my sleep at night makinmyself crazy out of my mind out of my mind wrote it down and read it out hopinit would save me too many times too many timesrefrain my love he makes me feel like nobody else nobody else but my love he doesn't love me so tell myself tell myselfpre one don't pick up the phone you know he's only callin'cause he's drunk and alone two don't let him in you'll have to kick him out again three don't be his friend you know you're gonna wake up in his bed in the morning and if you're under him you ain't gettinover him got new rules count 'em got new rules count 'em gotta tell them to myself got new rules count 'em gotta tell them to myself keep pushinforwards but he keeps pullinme backwards nowhere to turn no way nowhere to turn no now i'm standinback from it finally see the pattern never learn never learnrefrain but my love he doesn't love me so tell myself tell myself do do dopre one don't pick up the phone you know he's only callin'cause he'

In [531]:
def getVocab(lyrics):
    hashmap = {}
    for lyric in tqdm(lyrics):
        chars = list(lyric)
        for char in chars:
            if char not in hashmap:
                hashmap[char] = 1
            else:
                continue
    return list(hashmap.keys())

In [532]:
vocab = getVocab(lyrics)

  0%|          | 0/3422 [00:00<?, ?it/s]

100%|██████████| 3422/3422 [00:00<00:00, 11133.03it/s]


In [533]:
START_TOKEN = '<'
PAD_TOKEN = '$'
END_TOKEN = '>'

vocab.insert(0, START_TOKEN)
vocab.append(PAD_TOKEN)
vocab.append(END_TOKEN)

In [534]:
index_to_char = {k:v for k, v in enumerate(vocab)}
char_to_index = {v:k for k, v in enumerate(vocab)}

In [535]:
import statistics
lengths = [len(lyric) for lyric in lyrics]
mean_length = statistics.mean(lengths)
std_dev = statistics.stdev(lengths)
print("Mean: ", mean_length)
print("Standard Deviation: ", std_dev)

Mean:  1959.979544126242
Standard Deviation:  1488.036303883205


In [536]:
def filterLyrics(lyrics, min=1499, max=2501):
    filtered_lyrics = []
    for lyric in tqdm(lyrics):
        if len(lyric) > min and len(lyric) < max:
            filtered_lyrics.append(lyric)
    return filtered_lyrics

In [537]:
filtered_lyrics = filterLyrics(lyrics)

100%|██████████| 3422/3422 [00:00<00:00, 1454932.42it/s]


In [538]:
lengths = [len(lyric) for lyric in filtered_lyrics]
mean_length = statistics.mean(lengths)
std_dev = statistics.stdev(lengths)
print("Mean: ", mean_length)
print("Standard Deviation: ", std_dev)

Mean:  1924.6507064364207
Standard Deviation:  286.28491178941846


In [539]:
import random
random.shuffle(filtered_lyrics)
LYRICS = filtered_lyrics

In [540]:
len(LYRICS)

1274

In [542]:
LYRICS[0]

"baby got love for thee so deep inside of me don't know where to start yeah yeah love you more than anything but the words can't even touch what's in my heart no noprewhen try to explain it be sounding insane the words don't ever come out right get all tonguetied and twisted can't explain what i'm feeling and say baby baby ohwoah ohwoah baby baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my everything baby but every time try to say it words they only complicate it baby baby ohwoah ohwoah baby i'm so down for you no matter what you do real talk i'll be around yeah yeah ooh baby baby been feeling you before even knew what feelings were about oh babyprewhen try to explain it be sounding all crazy words don't ever come out right get all tonguetied and twisted can't explain what i'm feeling and say baby baby ohwoah ohwoah baby baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my everything baby but every time try to say it words

In [543]:
def inputOutput(LYRICS):
    input = []
    output = []
    for lyric in LYRICS:
        lyric = lyric.split(' ')
        length = len(lyric)
        split = int(0.7 * length)
        i = lyric[:split]
        i = " ".join(i)
        o = lyric[split:]
        o = " ".join(o)
        input.append(i)
        output.append(o)
    return input, output

In [544]:
input, output = inputOutput(LYRICS)

In [545]:
input[0]

"baby got love for thee so deep inside of me don't know where to start yeah yeah love you more than anything but the words can't even touch what's in my heart no noprewhen try to explain it be sounding insane the words don't ever come out right get all tonguetied and twisted can't explain what i'm feeling and say baby baby ohwoah ohwoah baby baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my everything baby but every time try to say it words they only complicate it baby baby ohwoah ohwoah baby i'm so down for you no matter what you do real talk i'll be around yeah yeah ooh baby baby been feeling you before even knew what feelings were about oh babyprewhen try to explain it be sounding all crazy words don't ever come out right get all tonguetied and twisted can't explain what i'm feeling and say baby baby ohwoah ohwoah baby baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my everything baby but every time try to say it words

In [546]:
output[0]

"that's why keep saying baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my everything baby but every time try to say it words they only complicate it baby baby ohwoah ohwoah baby baby baby baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my everything baby but every time try to say it words they only complicate it every time try to say it words they only complicate it every time try to say it words they only complicate it baby baby ohwoah ohwoah baby baby"

In [547]:
print(len(input), len(output))

1274 1274


In [548]:
PERCENTILE = 97
print( f"{PERCENTILE}th percentile length lyric: {np.percentile([len(x) for x in input], PERCENTILE)}" )

97th percentile length lyric: 1723.0


In [549]:
max_sequence_length = 1800

def is_valid_token(lyric, vocab):
    for token in list(set(lyric)):
        if token not in vocab:
            return False
    return True

def is_valid_length(lyric, max_sequence_length):
    return len(list(lyric)) < (max_sequence_length - 1)

In [550]:
valid_lyrics_indices = []
for i in tqdm(range(len(input))):
    ilyric, olyric = input[i], output[i]
    if is_valid_token(ilyric, vocab) and is_valid_length(ilyric, max_sequence_length) and is_valid_token(olyric, vocab) and is_valid_length(olyric, max_sequence_length):
        valid_lyrics_indices.append(i)

print(f"Number of lyrics: {len(input)}")
print(f"Number of valid lyrics: {len(valid_lyrics_indices)}")

  0%|          | 0/1274 [00:00<?, ?it/s]

100%|██████████| 1274/1274 [00:00<00:00, 10797.89it/s]

Number of lyrics: 1274
Number of valid lyrics: 1272





In [551]:
input_lyrics = [input[i] for i in valid_lyrics_indices]
output_lyrics = [output[i] for i in valid_lyrics_indices]

In [552]:
print(len(input_lyrics), len(input_lyrics))

1272 1272


In [553]:
train_split = 1000
train_ilyrics, train_olyrics = input_lyrics[:train_split], output_lyrics[:train_split]
test_ilyrics, test_olyrics = input_lyrics[train_split:], output_lyrics[train_split:]

In [554]:
print(train_ilyrics[0])
print(train_olyrics[0])

baby got love for thee so deep inside of me don't know where to start yeah yeah love you more than anything but the words can't even touch what's in my heart no noprewhen try to explain it be sounding insane the words don't ever come out right get all tonguetied and twisted can't explain what i'm feeling and say baby baby ohwoah ohwoah baby baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my everything baby but every time try to say it words they only complicate it baby baby ohwoah ohwoah baby i'm so down for you no matter what you do real talk i'll be around yeah yeah ooh baby baby been feeling you before even knew what feelings were about oh babyprewhen try to explain it be sounding all crazy words don't ever come out right get all tonguetied and twisted can't explain what i'm feeling and say baby baby ohwoah ohwoah baby baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my everything baby but every time try to say it words 

In [555]:
print(test_ilyrics[0])
print(test_olyrics[0])

baby will not pout baby will not cry 'cause got your love this christmas time when the snow's on the ground and it's freezing outside got your love this christmaspre on every list i've ever sent you're the gift i'd love the best so deck the halls and all the rest warm me up with your christmas love hey angel in the snow i'm under the mistletoe you are the one you're my very own christmas love tell santa i'm cool this year my present is standing right here thank god above for my very own christmas love yeah like beautiful tree you can light up the room but your kind of star can't be removed like beautiful carol get lost in your song and will forever sing alongpre on every list i've ever sent you're the gift i'd love the best so deck the halls and all the rest warm me up with your christmas love hey angel in the snow i'm under the mistletoe you are the one you're my very own christmas love tell santa i'm cool this year my present is standing right here thank god above for my very own chr

In [556]:
import numpy as np
import torch
from torch import nn
import math

In [557]:
def get_device():
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [558]:
def scaled_dot_product(q, k, v, mask=None):
  d_k = q.size()[-1]
  scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled.permute(1, 0, 2, 3) + mask
    scaled = scaled.permute(1, 0, 2, 3)
  attention = nn.functional.softmax(scaled, dim=-1)
  values = torch.matmul(attention, v)
  return values, attention

In [559]:
class PositionalEncoding(nn.Module):
  def __init__(self, dimension_length, max_sequence_length):
    super().__init__()
    self.dimension_length = dimension_length
    self.max_sequence_length = max_sequence_length

  def forward(self):
    index = torch.arange(0, self.dimension_length, 2).float()
    denominator = torch.pow(10000, index / self.dimension_length)
    position = torch.arange(self.max_sequence_length).reshape(self.max_sequence_length, 1)
    even_positional_encoding = torch.sin(position / denominator)
    odd_positional_encoding = torch.cos(position / denominator)
    stacked = torch.stack([even_positional_encoding, odd_positional_encoding], dim=2)
    positional_encoding = torch.flatten(stacked, start_dim = 1, end_dim = 2)
    return positional_encoding

In [560]:
class SentenceEmbedding(nn.Module):
  def __init__(self, max_sequence_length, dimension_length, language_to_index, START_TOKEN, END_TOKEN, PAD_TOKEN):
    super().__init__()
    self.vocab_size = len(language_to_index)
    self.max_sequence_length = max_sequence_length
    self.embedding = nn.Embedding(self.vocab_size, dimension_length)
    self.language_to_index = language_to_index
    self.positional_encoder = PositionalEncoding(dimension_length, max_sequence_length)
    self.dropout = nn.Dropout(p=0.1)
    self.START_TOKEN = START_TOKEN
    self.END_TOKEN = END_TOKEN
    self.PAD_TOKEN = PAD_TOKEN

  def batch_tokenize(self, batch, start_token, end_token):
    def tokenize(sentence, start_token, end_token):
      sentence_word_indicies = [self.language_to_index[token] for token in list(sentence)]
      if start_token:
        sentence_word_indicies.insert(0, self.language_to_index[self.START_TOKEN])
      if end_token:
        sentence_word_indicies.append(self.language_to_index[self.END_TOKEN])
      for _ in range(len(sentence_word_indicies), self.max_sequence_length):
        sentence_word_indicies.append(self.language_to_index[self.PAD_TOKEN])
      return torch.tensor(sentence_word_indicies)

    tokenized = []
    for sentence_index in range(len(batch)):
      tokenized.append(tokenize(batch[sentence_index], start_token, end_token))
    tokenized = torch.stack(tokenized)
    return tokenized.to(get_device())

  def forward(self, x, start_token, end_token):
    x = self.batch_tokenize(x, start_token, end_token)
    x = self.embedding(x)
    pos_enc = self.positional_encoder().to(get_device())
    x = self.dropout(x + pos_enc)
    return x

In [561]:
class MultiHeadAttention(nn.Module):
  def __init__(self, dimension_length, num_heads):
    super().__init__()
    self.dimension_length = dimension_length
    self.num_heads = num_heads
    self.head_dim = self.dimension_length // self.num_heads
    self.qkv_layer = nn.Linear(self.dimension_length, 3 * self.dimension_length)
    self.linear_layer = nn.Linear(self.dimension_length, self.dimension_length)

  def forward(self, x, mask):
    batch_size, sequence_length, dimension_length = x.size()
    qkv = self.qkv_layer(x)
    qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
    qkv = qkv.permute(0, 2, 1, 3)
    q, k, v = qkv.chunk(3, dim=-1)
    values, attention = scaled_dot_product(q, k, v, mask)
    values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
    output = self.linear_layer(values)
    return output

In [562]:
class LayerNormalization(nn.Module):
  def __init__(self, parameters_shape, eps=1e-2):
    super().__init__()
    self.parameters_shape = parameters_shape
    self.eps = eps
    self.mean = nn.Parameter(torch.zeros(parameters_shape))
    self.std = nn.Parameter(torch.ones(parameters_shape))

  def forward(self, inputs):
    dims = [-(i + 1) for i in range(len(self.parameters_shape))]
    mean = inputs.mean(dim=dims, keepdim=True)
    var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
    std = (var + self.eps).sqrt()
    y = (inputs - mean) / std
    output = self.std * y + self.mean
    return output

In [563]:
class FeedForward(nn.Module):
  def __init__(self, dimension_length, hidden_neurons, drop_prob=0.1):
    super(FeedForward, self).__init__()
    self.linear_layer0 = nn.Linear(dimension_length, hidden_neurons)
    self.linear_layer1 = nn.Linear(hidden_neurons, dimension_length)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(p=drop_prob)

  def forward(self, x):
    x = self.linear_layer0(x)
    x = self.relu(x)
    x = self.dropout(x)
    x = self.linear_layer1(x)
    return x

In [564]:
class EncoderLayer(nn.Module):
  def __init__(self, dimension_length, hidden_neurons, num_heads, drop_prob):
    super(EncoderLayer, self).__init__()
    self.attention = MultiHeadAttention(dimension_length=dimension_length, num_heads=num_heads)
    self.norm0 = LayerNormalization(parameters_shape=[dimension_length])
    self.dropout0 = nn.Dropout(p=drop_prob)
    self.ffn = FeedForward(dimension_length=dimension_length, hidden_neurons=hidden_neurons, drop_prob=drop_prob)
    self.norm1 = LayerNormalization(parameters_shape=[dimension_length])
    self.dropout1 = nn.Dropout(p=drop_prob)

  def forward(self, x, self_attention_mask):
    residual_x = x.clone()
    x = self.attention(x, mask=self_attention_mask)
    x = self.dropout0(x)
    x = self.norm0(x + residual_x)
    residual_x = x.clone()
    x = self.ffn(x)
    x = self.dropout1(x)
    x = self.norm1(x + residual_x)
    return x

In [565]:
class SequentialEncoder(nn.Sequential):
  def forward(self, *inputs):
    x, self_attention_mask  = inputs
    for module in self._modules.values():
        x = module(x, self_attention_mask)
    return x

In [566]:
class Encoder(nn.Module):
  def __init__(
      self,
      dimension_length,
      hidden_neurons,
      num_heads,
      drop_prob,
      num_layers,
      max_sequence_length,
      language_to_index,
      START_TOKEN,
      END_TOKEN,
      PAD_TOKEN
  ):
    super().__init__()
    self.sentence_embedding = SentenceEmbedding(max_sequence_length, dimension_length, language_to_index, START_TOKEN, END_TOKEN, PAD_TOKEN)
    self.layers = SequentialEncoder(*[EncoderLayer(dimension_length, hidden_neurons, num_heads, drop_prob) for _ in range(num_layers)])

  def forward(self, x, self_attention_mask, start_token, end_token):
    x = self.sentence_embedding(x, start_token, end_token)
    x = self.layers(x, self_attention_mask)
    return x

In [567]:
class MultiHeadCrossAttention(nn.Module):
  def __init__(self, dimension_length, num_heads):
    super().__init__()
    self.dimension_length = dimension_length
    self.num_heads = num_heads
    self.head_dims = self.dimension_length // self.num_heads
    self.kv_layer = nn.Linear(self.dimension_length, 2 * self.dimension_length)
    self.q_layer = nn.Linear(self.dimension_length, self.dimension_length)
    self.linear_layer = nn.Linear(self.dimension_length, self.dimension_length)

  def forward(self, x, y, mask):
    batch_size, sequence_length, dimension_length = x.size()
    kv = self.kv_layer(x)
    q = self.q_layer(y)
    kv = kv.reshape(batch_size, sequence_length, self.num_heads, 2 * self.head_dims)
    q = q.reshape(batch_size, sequence_length, self.num_heads, self.head_dims)
    kv = kv.permute(0, 2, 1, 3)
    q = q.permute(0, 2, 1, 3)
    k, v = kv.chunk(2, dim=-1)
    values, attention = scaled_dot_product(q, k, v, mask)
    values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, dimension_length)
    output = self.linear_layer(values)
    return output

In [568]:
class DecoderLayer(nn.Module):
  def __init__(self, dimension_length, hidden_neurons, num_heads, drop_prob):
    super(DecoderLayer, self).__init__()
    self.self_attention = MultiHeadAttention(dimension_length=dimension_length, num_heads=num_heads)
    self.layer_norm0 = LayerNormalization(parameters_shape=[dimension_length])
    self.dropout0 = nn.Dropout(p=drop_prob)

    self.cross_attention = MultiHeadCrossAttention(dimension_length=dimension_length, num_heads=num_heads)
    self.layer_norm1 = LayerNormalization(parameters_shape=[dimension_length])
    self.dropout1 = nn.Dropout(p=drop_prob)

    self.ffn = FeedForward(dimension_length=dimension_length, hidden_neurons=hidden_neurons, drop_prob=drop_prob)
    self.layer_norm2 = LayerNormalization(parameters_shape=[dimension_length])
    self.dropout2 = nn.Dropout(p=drop_prob)

  def forward(self, x, y, self_attention_mask, cross_attention_mask):
    residual_y = y.clone()
    y = self.self_attention(y, mask=self_attention_mask)
    y = self.dropout0(y)
    y = self.layer_norm0(y + residual_y)

    residual_y = y.clone()
    y = self.cross_attention(x, y, mask=cross_attention_mask)
    y = self.dropout1(y)
    y = self.layer_norm1(y + residual_y)

    residual_y = y.clone()
    y = self.ffn(y)
    y = self.dropout2(y)
    y = self.layer_norm2(y + residual_y)
    return y

In [569]:
class SequentialDecoder(nn.Sequential):
    def forward(self, *inputs):
        x, y, self_attention_mask, cross_attention_mask = inputs
        for module in self._modules.values():
            y = module(x, y, self_attention_mask, cross_attention_mask)
        return y

In [570]:
class Decoder(nn.Module):
  def __init__(
      self,
      dimension_length,
      hidden_neurons,
      num_heads,
      drop_prob,
      num_layers,
      max_sequence_length,
      language_to_index,
      START_TOKEN,
      END_TOKEN,
      PAD_TOKEN
  ):
    super().__init__()
    self.sentence_embedding = SentenceEmbedding(max_sequence_length, dimension_length, language_to_index, START_TOKEN, END_TOKEN, PAD_TOKEN)
    self.layers = SequentialDecoder(*[DecoderLayer(dimension_length, hidden_neurons, num_heads, drop_prob) for _ in range(num_layers)])

  def forward(self, x, y, self_attention_mask, cross_attention_mask, start_token, end_token):
    y = self.sentence_embedding(y, start_token, end_token)
    y = self.layers(x, y, self_attention_mask, cross_attention_mask)
    return y

In [571]:
class Transformer(nn.Module):
  def __init__(
      self,
      dimension_length,
      hidden_neurons,
      num_heads,
      drop_prob,
      num_layers,
      max_sequence_length,
      vocab_size,
      source_to_index,
      target_to_index,
      START_TOKEN,
      END_TOKEN,
      PAD_TOKEN
  ):
    super().__init__()
    self.encoder = Encoder(dimension_length, hidden_neurons, num_heads, drop_prob, num_layers, max_sequence_length, source_to_index, START_TOKEN, END_TOKEN, PAD_TOKEN)
    self.decoder = Decoder(dimension_length, hidden_neurons, num_heads, drop_prob, num_layers, max_sequence_length, target_to_index, START_TOKEN, END_TOKEN, PAD_TOKEN)
    self.linear = nn.Linear(dimension_length, vocab_size)
    self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

  def forward(
      self,
      x,
      y,
      encoder_self_attention_mask=None,
      decoder_self_attention_mask=None,
      decoder_cross_attention_mask=None,
      enc_start_token=False,
      enc_end_token=False,
      dec_start_token=False,
      dec_end_token=False
  ):
    x = self.encoder(x, encoder_self_attention_mask, start_token=enc_start_token, end_token=enc_end_token)
    x = self.decoder(x, y, decoder_self_attention_mask, decoder_cross_attention_mask, start_token=dec_start_token, end_token=dec_end_token)
    output = self.linear(x)
    return output

In [572]:
d_model = 512
batch_size = 15
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 1800
vocab_size = len(vocab)

transformer = Transformer(
    d_model,
    ffn_hidden,
    num_heads,
    drop_prob,
    num_layers,
    max_sequence_length,
    vocab_size,
    char_to_index,
    char_to_index,
    START_TOKEN,
    END_TOKEN,
    PAD_TOKEN
)

In [573]:
transformer

Transformer(
  (encoder): Encoder(
    (sentence_embedding): SentenceEmbedding(
      (embedding): Embedding(2326, 512)
      (positional_encoder): PositionalEncoding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): SequentialEncoder(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (qkv_layer): Linear(in_features=512, out_features=1536, bias=True)
          (linear_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm0): LayerNormalization()
        (dropout0): Dropout(p=0.1, inplace=False)
        (ffn): FeedForward(
          (linear_layer0): Linear(in_features=512, out_features=2048, bias=True)
          (linear_layer1): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm1): LayerNormalization()
        (dropout1): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (sentence_embeddi

In [574]:
from torch.utils.data import Dataset, DataLoader

class LyricDataset(Dataset):
    def __init__(self, input_lyrics, output_lyrics):
        self.input_lyrics = input_lyrics
        self.output_lyrics = output_lyrics

    def __len__(self):
        return len(self.input_lyrics)

    def __getitem__(self, index):
        return self.input_lyrics[index], self.output_lyrics[index]

In [575]:
train_dataset = LyricDataset(train_ilyrics, train_olyrics)
test_dataset = LyricDataset(test_ilyrics, test_olyrics)

In [576]:
train_loader = DataLoader(train_dataset, batch_size)
test_loader = DataLoader(test_dataset, batch_size)

In [580]:
iterator = iter(train_loader)
for batch_num, batch in enumerate(iterator):
    print(batch)
    break

[("baby got love for thee so deep inside of me don't know where to start yeah yeah love you more than anything but the words can't even touch what's in my heart no noprewhen try to explain it be sounding insane the words don't ever come out right get all tonguetied and twisted can't explain what i'm feeling and say baby baby ohwoah ohwoah baby baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my everything baby but every time try to say it words they only complicate it baby baby ohwoah ohwoah baby i'm so down for you no matter what you do real talk i'll be around yeah yeah ooh baby baby been feeling you before even knew what feelings were about oh babyprewhen try to explain it be sounding all crazy words don't ever come out right get all tonguetied and twisted can't explain what i'm feeling and say baby baby ohwoah ohwoah baby baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my everything baby but every time try to say it wor

In [581]:
criterian = nn.CrossEntropyLoss(ignore_index=char_to_index[PAD_TOKEN], reduction='none')

for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [582]:
NEG_INFINITY = -1e9

def create_masks(input_batch, output_batch):
    num_lyrics = len(input_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length], True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_lyrics, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_lyrics, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_lyrics, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_lyrics):
        input_lyric_length, output_lyric_length = len(input_batch[idx]), len(output_batch[idx])
        input_chars_to_padding_mask = np.arange(input_lyric_length + 1, max_sequence_length)
        output_chars_to_padding_mask = np.arange(output_lyric_length + 1, max_sequence_length)
        encoder_padding_mask[idx, :, input_chars_to_padding_mask] = True
        encoder_padding_mask[idx, input_chars_to_padding_mask, :] = True
        decoder_padding_mask_self_attention[idx, :, output_chars_to_padding_mask] = True
        decoder_padding_mask_self_attention[idx, output_chars_to_padding_mask, :] = True
        decoder_padding_mask_cross_attention[idx, :, input_chars_to_padding_mask] = True
        decoder_padding_mask_cross_attention[idx, output_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFINITY, 0)
    decoder_self_attention_mask = torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFINITY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFINITY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [583]:
transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        input_batch, output_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(input_batch, output_batch)
        optim.zero_grad()
        output_predictions = transformer(
            input_batch,
            output_batch,
            encoder_self_attention_mask.to(device),
            decoder_self_attention_mask.to(device),
            decoder_cross_attention_mask.to(device),
            enc_start_token=False,
            enc_end_token=False,
            dec_start_token=True,
            dec_end_token=True
        )
        labels = transformer.decoder.sentence_embedding.batch_tokenize(output_batch, start_token=False, end_token=True)
        loss = criterian(output_predictions.view(-1, vocab_size).to(device), labels.view(-1).to(device)).to(device)
        valid_indicies = torch.where(labels.view(-1) == char_to_index[PAD_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"Input: {input_batch[0]}")
            print(f"Output: {output_batch[0]}")
            output_predicted = torch.argmax(output_predictions[0], axis=1)
            predicted_lyric = ""
            for idx in output_predicted:
              if idx == char_to_index[END_TOKEN]:
                break
              predicted_lyric += index_to_char[idx.item()]
            print(f"Prediction: {predicted_lyric}")
            transformer.eval()

Epoch 0
Iteration 0 : 7.907369613647461
Input: baby got love for thee so deep inside of me don't know where to start yeah yeah love you more than anything but the words can't even touch what's in my heart no noprewhen try to explain it be sounding insane the words don't ever come out right get all tonguetied and twisted can't explain what i'm feeling and say baby baby ohwoah ohwoah baby baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my everything baby but every time try to say it words they only complicate it baby baby ohwoah ohwoah baby i'm so down for you no matter what you do real talk i'll be around yeah yeah ooh baby baby been feeling you before even knew what feelings were about oh babyprewhen try to explain it be sounding all crazy words don't ever come out right get all tonguetied and twisted can't explain what i'm feeling and say baby baby ohwoah ohwoah baby baby oh baby oh baby my baby baby oh baby baby baby all i'm tryna say is you're my ever