In [20]:
# !pip install torch torchvision torchaudio

In [59]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence


In [73]:
class LyricsDataset(Dataset):
    """Next-word prediction samples built from a plain-text lyrics file.

    Each non-blank line of the file is one candidate sample: its first
    ``context_window`` whitespace-separated tokens form the input and the
    token immediately after them is the target.

    Attributes:
        lyrics: Stripped, non-empty lines read from the file.
        vocab: Sorted unique tokens, with '<unk>' always as the final entry.
        word2idx: Mapping from token to its position in ``vocab``.
        context_window: Number of leading tokens used as model input.
    """

    def __init__(self, file_path, context_window=5):
        """Read ``file_path`` (UTF-8) and build the vocabulary.

        Args:
            file_path: Path to a text file with one lyric line per row.
            context_window: How many leading tokens of a line feed the model.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            self.lyrics = [line.strip() for line in file if line.strip()]

        unique_words = set(' '.join(self.lyrics).split())
        # If the literal token '<unk>' occurs in the text, drop it here so it
        # is not listed twice once the fallback entry is appended below; a
        # duplicate would make len(vocab) larger than len(word2idx).
        unique_words.discard('<unk>')
        self.vocab = sorted(unique_words)
        self.vocab.append('<unk>')  # fallback token, always the last index
        self.word2idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.context_window = context_window

    def __len__(self):
        """Return the number of lyric lines, including ones too short to use."""
        return len(self.lyrics)

    def __getitem__(self, idx):
        """Return ``(input_indices, target_index)`` for line ``idx``.

        Returns:
            A pair of ``torch.long`` tensors: a 1-D tensor with
            ``context_window`` token indices and a 0-D tensor with the index
            of the following token. ``None`` is returned when the line has no
            token after the context window; the collate function is expected
            to filter those out.
        """
        tokens = self.lyrics[idx].split()

        # A line needs at least context_window + 1 tokens to provide a target.
        if len(tokens) <= self.context_window:
            return None

        unk_index = self.word2idx['<unk>']
        input_seq = tokens[:self.context_window]
        target_word = tokens[self.context_window]

        input_seq_indices = [self.word2idx.get(token, unk_index) for token in input_seq]
        target_word_index = self.word2idx.get(target_word, unk_index)

        return (
            torch.tensor(input_seq_indices, dtype=torch.long),
            torch.tensor(target_word_index, dtype=torch.long),
        )

In [74]:
class SimpleTransformer(nn.Module):
    """Token embedding -> Transformer encoder stack -> per-position vocabulary logits.

    The model takes batch-major input: ``x`` has shape ``(batch, seq_len)`` and
    holds token indices; the output has shape ``(batch, seq_len, vocab_size)``.
    No positional encoding is added, so the encoder receives no explicit
    word-order signal.

    Args:
        vocab_size: Number of rows in the embedding table and output logits.
        hidden_dim: Embedding width, used as the encoder's ``d_model``.
        num_layers: Number of stacked encoder layers.
        num_heads: Attention heads per layer (must divide ``hidden_dim``).
    """

    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads):
        super(SimpleTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        # batch_first=True is required because callers pass (batch, seq) tensors.
        # With the default (seq-first) layout the encoder would treat the batch
        # axis as the sequence and attend across unrelated samples.
        # The attribute is kept so existing state_dict keys stay valid; the
        # encoder below operates on its own copies of this layer.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map token indices ``(batch, seq_len)`` to logits ``(batch, seq_len, vocab_size)``."""
        x = self.embedding(x)
        x = self.transformer_encoder(x)
        x = self.fc(x)
        return x

In [75]:
# Hyperparameters
# -- model architecture --
hidden_dim = 256   # embedding width, also passed as the encoder's d_model
num_heads = 4      # attention heads per encoder layer
num_layers = 4     # stacked encoder layers

# -- optimisation --
batch_size = 16        # samples per DataLoader batch
learning_rate = 1e-3   # Adam step size
num_epochs = 10        # full passes over the dataset

In [81]:
def collate_fn(batch):
    """Merge dataset samples into one batch, ignoring ``None`` entries.

    Args:
        batch: Iterable of ``(input_tensor, target_tensor)`` pairs; entries
            that are ``None`` are dropped.

    Returns:
        ``(inputs, targets)`` where ``inputs`` is the batch-first stack of the
        input tensors, right-padded with 0 to the longest one, and ``targets``
        is the stacked target tensors. ``None`` if no usable sample remains.
    """
    usable = [sample for sample in batch if sample is not None]
    if not usable:
        return None

    # Transpose the list of pairs into a list of inputs and a list of targets.
    contexts, targets = map(list, zip(*usable))

    padded_contexts = pad_sequence(contexts, batch_first=True, padding_value=0)
    return padded_contexts, torch.stack(targets)

In [82]:
# Load the dataset
dataset = LyricsDataset('all_tswift_lyrics_cleaned.txt')
# shuffle=True reshuffles the lines every epoch, so training does not see the
# same batches in the same (file) order each time.
# collate_fn drops the None samples that too-short lines produce.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [83]:
# Vocabulary size: unique whitespace-separated tokens in the lyrics plus the '<unk>' entry
len(dataset.vocab)

4420

In [84]:
# Build the model together with its loss function and optimizer
device = torch.device('cpu')  # everything in this notebook runs on the CPU

vocab_size = len(dataset.vocab)
model = SimpleTransformer(
    vocab_size=vocab_size,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    num_heads=num_heads,
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Place the model on the device; as the cell's last expression this also
# displays the module summary.
model.to(device)

SimpleTransformer(
  (embedding): Embedding(4420, 256)
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
    (linear1): Linear(in_features=256, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=256, bias=True)
    (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=2048, bias=True)
        (dro

In [85]:
# Train the model
model.train()  # make sure dropout is active, even if an evaluation cell ran earlier

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        # collate_fn returns None when every line in the batch was too short.
        if batch is None:
            continue

        input_seq, target_seq = batch
        input_seq = input_seq.to(device)
        target_seq = target_seq.to(device).view(-1)

        optimizer.zero_grad()
        outputs = model(input_seq)
        outputs = outputs[:, -1, :]  # logits at the last position of each sequence

        loss = criterion(outputs, target_seq)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than only the last batch's
    # value, and avoid referencing `loss` when no batch was usable.
    if num_batches == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], no usable batches")
        continue

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss / num_batches:.4f}")

ValueError: Expected input batch_size (65) to match target batch_size (13).

In [None]:
# Evaluate the trained model: sample one next word for a prompt
test_lyrics = "So shame on "

was_training = model.training
model.eval()  # switch dropout off so inference is not randomly perturbed
with torch.no_grad():
    unk_index = dataset.word2idx['<unk>']
    test_seq = torch.tensor(
        [dataset.word2idx.get(word, unk_index) for word in test_lyrics.split()],
        dtype=torch.long,
        device=next(model.parameters()).device,  # keep the input on the model's device
    )
    output = model(test_seq.unsqueeze(0))  # add a batch dimension of size 1
    output_probs = torch.softmax(output[-1, -1], dim=-1)  # distribution after the last prompt token
    predicted_index = torch.multinomial(output_probs, num_samples=1).item()
model.train(was_training)  # restore whichever mode the model was in before this cell

if predicted_index < len(dataset.vocab):
    predicted_word = dataset.vocab[predicted_index]
    # rstrip() keeps a trailing space in the prompt from producing a double space.
    predicted_lyrics = test_lyrics.rstrip() + ' ' + predicted_word
else:
    # Defined here as well so the prints below cannot raise NameError.
    predicted_word = None
    predicted_lyrics = test_lyrics
    print(f"Skipping prediction. Predicted index {predicted_index} is out of range.")

print("Input lyrics:", test_lyrics)
print("Predicted word:", predicted_word)
print("Predicted lyrics:", predicted_lyrics)

Input lyrics: So shame on 
Predicted word: now.
Predicted lyrics: So shame on  now.


In [98]:
# Evaluate the trained model: sample several distinct candidate next words
test_lyrics = "I knew you were"
num_predictions = 5

was_training = model.training
model.eval()  # switch dropout off so inference is not randomly perturbed
with torch.no_grad():
    unk_index = dataset.word2idx['<unk>']
    test_seq = torch.tensor(
        [dataset.word2idx.get(word, unk_index) for word in test_lyrics.split()],
        dtype=torch.long,
        device=next(model.parameters()).device,  # keep the input on the model's device
    )
    output = model(test_seq.unsqueeze(0))  # add a batch dimension of size 1
    output_probs = torch.softmax(output[-1, -1], dim=-1)  # distribution after the last prompt token

    # multinomial draws without replacement by default, so it cannot return
    # more samples than there are vocabulary entries.
    num_samples = min(num_predictions, output_probs.numel())
    predicted_indices = torch.multinomial(output_probs, num_samples=num_samples)
model.train(was_training)  # restore whichever mode the model was in before this cell

sampled_ids = predicted_indices.tolist()
predicted_words = [dataset.vocab[i] for i in sampled_ids]
predicted_probs = [output_probs[i].item() for i in sampled_ids]

print("Input lyrics:", test_lyrics)
print("Predicted words:", predicted_words)
print("Predicted probabilities:", predicted_probs)

Input lyrics: I knew you were
Predicted words: ["lookin'", 'at', 'what', 'the', 'have']
Predicted probabilities: [0.020084301009774208, 0.00029687696951441467, 0.08618093281984329, 0.2464924454689026, 0.003065008670091629]
