In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
import math
from torch.utils.data import Dataset, DataLoader

In [2]:
# Fetch the JSB Chorales dataset repository; the pickle loaded in the next cell lives in it.
# NOTE: git refuses to clone into an existing non-empty directory, so re-running this
# cell reports an error once the repo is already present (the notebook keeps running).
!git clone https://github.com/czhuang/JSB-Chorales-dataset.git

Cloning into 'JSB-Chorales-dataset'...
remote: Enumerating objects: 46, done.[K
remote: Counting objects: 100% (10/10), done.[K
remote: Compressing objects: 100% (6/6), done.[K
remote: Total 46 (delta 4), reused 10 (delta 4), pack-reused 36 (from 1)[K
Receiving objects: 100% (46/46), 2.78 MiB | 13.24 MiB/s, done.
Resolving deltas: 100% (12/12), done.


In [3]:
# Load the train/test/valid splits of the chorales (16th-note resolution file).
# NOTE(review): pickle.load can execute arbitrary code -- only load this file from
# the trusted upstream repository. encoding="latin1" presumably exists because the
# pickle was written by Python 2 -- confirm.
chorales_pkl = 'JSB-Chorales-dataset/jsb-chorales-16th.pkl'
with open(chorales_pkl, 'rb') as fh:
    data = pickle.load(fh, encoding="latin1")

In [10]:
# Model hyperparameters (read as globals by Head / MultiHeadAttention below).
n_embd = 128                  # width of each token representation
n_head = 8                    # number of attention heads
# Heads are concatenated back to n_embd features, so the width must split evenly;
# fail here rather than with an opaque shape error inside the projection layer.
assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
head_size = n_embd // n_head  # per-head width
block_size = 8                # maximum sequence length the attention mask/bias cover
dropout = 0.2                 # dropout probability


In [5]:
# The unpickled object is a dict keyed by split name (see the output below).
data.keys()

dict_keys(['test', 'train', 'valid'])

In [6]:
# Number of chorales in each split, printed in the same "(N,)" form as before.
# len() is used instead of np.array(..., dtype="object").shape: numpy adds extra
# dimensions whenever the nested lists happen to share lengths, so the ndarray
# shape is not reliably just the sequence count (and the copy is wasted work).
for split in ('train', 'test', 'valid'):
    print((len(data[split]),))

(229,)
(77,)
(76,)


In [7]:
# Inspect the structure of the training split. Per the output below: a list of
# sequences, each a list of per-time-step tuples of np.int64 values.
# NOTE(review): the tuple values look like MIDI pitch numbers of the notes sounding
# at that step (tuple length may vary with the number of voices) -- confirm.
train_data = data['train']
print(f"Type of train_data: {type(train_data)}")
print(f"Number of sequences in train_data: {len(train_data)}")
print(f"Type of the first sequence: {type(train_data[0])}")
print(f"Length of the first sequence: {len(train_data[0])}")
print(f"Type of the first element in the first sequence: {type(train_data[0][0])}")
print(f"Content of the first element in the first sequence:\n{train_data[0][0]}")

Type of train_data: <class 'list'>
Number of sequences in train_data: 229
Type of the first sequence: <class 'list'>
Length of the first sequence: 192
Type of the first element in the first sequence: <class 'tuple'>
Content of the first element in the first sequence:
(np.int64(74), np.int64(70), np.int64(65), np.int64(58))


In [8]:
class Head(nn.Module):
  """One head of causal (masked) self-attention with a learned positional bias.

  Args:
    head_size: output feature width of this head (also the q/k/v width).
    embed_dim: input feature width; defaults to the global ``n_embd``
      (looked up when the head is constructed).
    context_len: maximum sequence length supported by the mask and bias;
      defaults to the global ``block_size``.
    dropout_p: dropout probability applied to the attention weights;
      defaults to the global ``dropout``.

  forward(x): x is (B, T, embed_dim) with T <= context_len; returns
  (B, T, head_size).

  Raises:
    ValueError: if T exceeds context_len.
  """

  def __init__(self, head_size, embed_dim=None, context_len=None, dropout_p=None):
    super().__init__()
    # Resolve defaults lazily so the notebook globals are read at construction
    # time, just as the original implementation did.
    embed_dim = n_embd if embed_dim is None else embed_dim
    context_len = block_size if context_len is None else context_len
    dropout_p = dropout if dropout_p is None else dropout_p

    self.key = nn.Linear(embed_dim, head_size, bias=False)
    self.query = nn.Linear(embed_dim, head_size, bias=False)
    self.value = nn.Linear(embed_dim, head_size, bias=False)
    # Lower-triangular mask: query position i may only attend to key positions <= i.
    self.register_buffer('tril', torch.tril(torch.ones(context_len, context_len)))
    self.dropout = nn.Dropout(dropout_p)

    # Learnable additive bias with one entry per (query position, key position)
    # pair. This is an absolute pairwise bias: entries sharing the same offset
    # i - j are independent parameters, not a shared relative bias.
    self.relative_bias = nn.Parameter(torch.randn(context_len, context_len))

  def forward(self, x):
    T = x.size(1)
    max_len = self.tril.size(0)
    if T > max_len:
      raise ValueError(f"sequence length {T} exceeds context length {max_len}")

    k = self.key(x)   # (B, T, head_size)
    q = self.query(x) # (B, T, head_size)
    v = self.value(x) # (B, T, head_size)

    # Scaled dot-product scores: (B, T, hs) @ (B, hs, T) -> (B, T, T).
    # Scale by 1/sqrt(d_k) where d_k is the q/k width (head_size), not the
    # input width; scaling by the input width over-flattens the softmax.
    wei = q @ k.transpose(-2, -1) * k.size(-1) ** -0.5

    # Add the learned positional bias for the first T positions.
    wei = wei + self.relative_bias[:T, :T]

    # Causal masking: future positions get -inf, i.e. zero weight after softmax.
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

    wei = F.softmax(wei, dim=-1)
    wei = self.dropout(wei)

    # Weighted aggregation of the values: (B, T, T) @ (B, T, hs) -> (B, T, hs).
    return wei @ v

In [9]:
class MultiHeadAttention(nn.Module):
  """Several causal self-attention ``Head``s run in parallel, then projected.

  Args:
    num_heads: number of heads (default: global ``n_head``).
    head_size: output width of each head (default: ``n_embd // n_head``).
      Both defaults are fixed when this class is defined.

  The concatenated head outputs (num_heads * head_size features) are linearly
  projected to ``n_embd`` features and passed through dropout (global ``dropout``).
  """
  def __init__(self, num_heads=n_head, head_size=n_embd//n_head):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    # The projection consumes the concatenated head width. This equals n_embd only
    # when num_heads * head_size == n_embd (the default), so size it explicitly
    # instead of assuming n_embd inputs.
    self.proj = nn.Linear(num_heads * head_size, n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    # (B, T, num_heads * head_size) -> (B, T, n_embd)
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    return self.dropout(self.proj(out))

In [None]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: causal self-attention, then a feed-forward
    MLP, each wrapped in a residual connection.

    Args:
        n_embd: model width; must be divisible by ``n_head``.
        n_head: number of attention heads.
        dropout: dropout probability for the feed-forward sub-layer.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.

    NOTE(review): the attention sub-layer is built by ``MultiHeadAttention``/``Head``,
    which read the notebook-level ``n_embd`` and ``dropout`` globals rather than
    these arguments. So ``n_embd`` here must match the global, and attention
    dropout uses the global rate.
    """

    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        if n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
            )
        # Derive the per-head width from this block's own arguments rather than the
        # notebook-global `head_size`, so the heads always concatenate to n_embd.
        self.attn = MultiHeadAttention(n_head, n_embd // n_head)
        self.norm1 = nn.LayerNorm(n_embd)
        self.norm2 = nn.LayerNorm(n_embd)

        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4),  # expand (feed-forward inner layer)
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(n_embd * 4, n_embd),  # project back
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Pre-LN attention sub-layer with residual connection.
        x = x + self.attn(self.norm1(x))

        # Pre-LN feed-forward sub-layer with residual connection.
        x = x + self.ffwd(self.norm2(x))

        return x
