In [None]:
#! pip install pyarrow matplotlib sentencepiece pandas

In [1]:
import torch
import numpy as np
import math
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.nn import TransformerDecoder, TransformerDecoderLayer
import sentencepiece as spm
import matplotlib.pyplot as plt
import multiprocessing
import time
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

if torch.backends.mps.is_available():  # Check for Apple Silicon GPU availability (requires PyTorch 1.12 or later)
    device = torch.device("mps")
elif torch.cuda.is_available():  # Check for NVIDIA GPU availability
    device = torch.device("cuda")
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")  # Fall back to CPU

print(f"Using device: {device}")

Using device: mps


In [None]:
# allfiles = [
#     '../train/0000.parquet',
#     # '../train/0001.parquet',
#     # '../train/0002.parquet',
#     # '../train/0003.parquet',
#     # '../validate/0000.parquet',
#     ]

# text_column = 'text'

# # Initialize an empty list to store text
# all_text = []

# for file in allfiles:
#     # Read the parquet file
#     df = pd.read_parquet(file)
#     # Append the text data to the list
#     all_text.extend(df[text_column].tolist())

# # Optional: Save all text to a single file if preferred
# with open('0000_text.txt', 'w', encoding='utf-8') as f:
#     for line in all_text:
#         f.write(f"{line}\n")

# Train SentencePiece model directly from Python list of sentences
# spm.SentencePieceTrainer.train(input='all_text.txt', model_prefix='spm_full_text_model', vocab_size=10000, model_type='unigram')
# spm.SentencePieceTrainer.train(input='0000_text.txt', model_prefix='spm_0000_text_model', vocab_size=10000, model_type='unigram')

In [2]:
train = pd.read_parquet('../train/0000.parquet')
# Select only the first 10000 rows
train = train.iloc[:10000]
train['text'] = train['text'].str.slice(start=0, stop=500)

# validate = pd.read_parquet('../validate/0000.parquet')

In [3]:
class StoryDataset(Dataset):
    def __init__(self, stories, sp_model_path):
        self.stories = stories
        # print(self.stories)
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(sp_model_path)

    def __len__(self):
        return len(self.stories)

    def __getitem__(self, idx):
        encoded_story = [11000] + self.sp.EncodeAsIds(self.stories[idx])
        encoded_target = self.sp.EncodeAsIds(self.stories[idx]) + [12000]
        return torch.tensor(encoded_story, dtype=torch.long), torch.tensor(encoded_target, dtype=torch.long)

In [5]:
def collate_fn(batch):
    stories, targets = zip(*batch)
    
    # Padding sequences to have the same length within a batch
    padded_stories = pad_sequence(stories, batch_first=True, padding_value=13000)  # Assuming 0 is your padding ID
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=13000)  # Adjust padding_value if necessary
    
    return padded_stories, padded_targets

In [17]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        
        # self.queries = nn.Linear(embed_size, embed_size)
        # self.keys = nn.Linear(embed_size, embed_size)
        # self.values = nn.Linear(embed_size, embed_size)

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    
    def forward(self, value, key, query, mask = None):
        N = query.shape[0]
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = value.reshape(N, value_len, self.heads, self.head_dim)
        keys = key.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does matrix multiplication for query and keys for each training example
        # and each head, with scaling by the dimension's square root

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # print("energy: ", energy.shape)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        
                    # Generate query, key, value matrices
                    # Q = self.queries(x) # Batch size : BatchTokensLen: EmbedingSize
                    # K = self.keys(x)    # Batch size : BatchTokensLen: EmbedingSize
                    # V = self.values(x)  # Batch size : BatchTokensLen: EmbedingSize
             
                    # # Calculate the attention scores
                    # attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_size ** 0.5)
                   
                    # if mask is not None:
                    #     attention_scores = attention_scores.masked_fill(mask == 0, float("-1e20"))
       
                    # attention = F.softmax(attention_scores, dim=-1)
       
                    # Apply attention scores to values
                    # out = torch.matmul(attention, V)
       
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads):
        super(TransformerBlock, self).__init__()
        # self.attention = SelfAttention(embed_size)
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, 4 * embed_size),
            nn.ReLU(),
            nn.Linear(4 * embed_size, embed_size)
        )
    
    def forward(self, value, key, query, mask):
        # Self-attention
        attention_out = self.attention(value, key, query, mask)
        x = self.norm1(attention_out + query)
        
        # Feed-forward
        ff_out = self.feed_forward(x)
        x = self.norm2(ff_out + x)
        return x
        
def positional_encoding(max_seq_len, embed_size, device):
    
    # Initialize the positional encoding matrix
    pos_enc = torch.zeros((max_seq_len, embed_size), device=device)
    # Use torch.arange to generate positions (0 to max_seq_len-1)
    positions = torch.arange(0, max_seq_len, dtype=torch.float, device=device).unsqueeze(1)
    # Divisor to apply to the positions
    div_term = torch.exp(torch.arange(0, embed_size, 2, device=device).float() * -(math.log(10000.0) / embed_size))

    pos_enc[:, 0::2] = torch.sin(positions * div_term)  # Apply sin to even indices
    pos_enc[:, 1::2] = torch.cos(positions * div_term)  # Apply cos to odd indices

    return pos_enc.unsqueeze(0)  # Add batch dimension for convenience

class SimpleTransformer(nn.Module):
    def __init__(self, embed_size, vocab_size, max_length, heads, device):
        super(SimpleTransformer, self).__init__()
        self.embed = nn.Embedding(13001, embed_size)  # +3 for SOS, EOS, and padding
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_length, embed_size))
        self.device = device
        self.embed_size = embed_size
        self.max_length = max_length
        
        self.transformer_block = TransformerBlock(embed_size, heads)
        self.fc_out = nn.Linear(embed_size, 13001)  # +3 for SOS, EOS, and padding
    
    
    def forward(self, x):
        # print(x.shape)
        batch_size, seq_len = x.shape
        positions = torch.arange(0, seq_len).unsqueeze(0).to(self.device)
        
        # x != 13000 creates a mask where padded positions are False
        # unsqueeze operations adjust the mask dimensions to match the attention scores tensor
        # mask = (x != 13000).unsqueeze(1).unsqueeze(2)

        padding_mask = (x != 13000) # Identify padding tokens, shape: [batch_size, seq_len]
        print(padding_mask)
        # print("x ", x.shape, "padding_mask: ", padding_mask.shape) 
        # padding_mask = padding_mask.unsqueeze(1)  # Add dimension for broadcasting, shape: [batch_size, 1, seq_len]
        padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)  # Shape: [N, 1, 1, seq_len]

         # Create look-ahead mask
        look_ahead_mask = torch.triu(torch.ones((1,1, seq_len, seq_len), device=self.device), diagonal=1).bool()
        print(look_ahead_mask)
        # print("padding_mask: ", padding_mask.shape)
        # print("look_ahead_mask: ", look_ahead_mask.shape)
        
        # look_ahead_mask = look_ahead_mask.unsqueeze(0)  # Add batch dimension [1, seq_len, seq_len]
        combined_mask = padding_mask & look_ahead_mask  # Use logical OR to combine both masks
        # print("combined_mask: ", combined_mask.shape)
        
        # Create a mask for padding tokens
        x = self.embed(x)
        # pos_encoding = self.positional_encoding[:, :x.size(1), :]
        pos_encoding = positional_encoding(seq_len, self.embed_size, self.device)
        # print("pos_encoding: ", pos_encoding.shape)
        # print("x: ",x.shape)
        x = x + pos_encoding
        
        # print(x.shape,padding_mask.shape, look_ahead_mask.shape)
        # print(x,padding_mask, look_ahead_mask)
        
        # Combine masks: Ensure the combined mask is broadcastable to the attention scores shape
        # combined_mask = torch.max(padding_mask, look_ahead_mask)  # Element-wise max to combine masks

        # Since both masks are used in the same attention mechanism now, you can combine them
        # Ensure broadcasting works by aligning dimensions: [N, 1, seq_len, seq_len] for both masks
        
        x = self.transformer_block(x, x, x, combined_mask)
        out = self.fc_out(x)
        return out
    
# Example initialization parameters
vocab_size = 10000  # Not including special tokens
embed_size = 256
max_length = 2000 # Adjust based on your sequence length requirements

model = SimpleTransformer(embed_size, vocab_size, max_length, 8, device)
# Move your model to the selected device
model.to(device)

SimpleTransformer(
  (embed): Embedding(13001, 256)
  (transformer_block): TransformerBlock(
    (attention): SelfAttention(
      (values): Linear(in_features=32, out_features=32, bias=False)
      (keys): Linear(in_features=32, out_features=32, bias=False)
      (queries): Linear(in_features=32, out_features=32, bias=False)
      (fc_out): Linear(in_features=256, out_features=256, bias=True)
    )
    (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (feed_forward): Sequential(
      (0): Linear(in_features=256, out_features=1024, bias=True)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=256, bias=True)
    )
  )
  (fc_out): Linear(in_features=256, out_features=13001, bias=True)
)

In [18]:
def masked_cross_entropy(logits, target, mask):
    logits_flat = logits.view(-1, logits.size(-1))
    target_flat = target.view(-1)
    mask_flat = mask.view(-1)
    
    losses = F.cross_entropy(logits_flat, target_flat, reduction='none')
    masked_losses = losses * mask_flat
    return masked_losses.sum() / mask_flat.sum()

In [19]:
BATCH_SIZE = 8
NUM_WORKERS = 8  # Based on your system's capabilities

# Assuming `train_stories`, `train_targets` are your training data and targets
train_dataset = StoryDataset(train['text'], "spm_0000_text_model.model")
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=collate_fn,
    # num_workers=NUM_WORKERS
)

for batch_idx, (stories, targets) in enumerate(train_loader):
    print(stories.shape)
    print(targets.shape)
    break

torch.Size([8, 122])
torch.Size([8, 122])


In [20]:
def validation(model):
    initial_text = "Once upon a time "
    sp = spm.SentencePieceProcessor()
    sp.Load("spm_0000_text_model.model")
    encoded_input =  sp.EncodeAsIds(initial_text)
    print(encoded_input)
    input_tensor = torch.tensor([encoded_input], dtype=torch.long).to(device)  # Assuming batch_size=1
    
    max_length_ = 50  # Maximum length of the generated sequence
    eos_token_id = 12000  # Assuming '</s>' is your end-of-sequence token
    
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        for _ in range(max_length_):
            output = model(input_tensor)  # Get the model's predictions
            # Get the last predicted token (logits) as the next token. Shape of output: (batch_size, sequence_length, vocab_size)
          
            # Choose a random token from the top 5
            last_token_logits = output[:, -1, :]


            last_token_logits = output[:, -1, :]
            predicted_token_id = torch.argmax(last_token_logits, dim=-1).unsqueeze(0)  # 
            
            # print("predicted_token_id: ", predicted_token_id.shape)
            # print("predicted_token_id: ", predicted_token_id)
            
            topk_values, topk_indices = torch.topk(last_token_logits, 3, dim=-1)
            # print("topk_values: ", topk_values)
            # print("topk_indices: ", topk_indices)
                # print(last_token_logits, topk_values, topk_indices)
                # probs = torch.softmax(topk_values, dim=-1)
                # print("probs: ", probs)
                # sampled_index = torch.multinomial(probs, 1)
            sampled_index = torch.randint(0, topk_indices.size(-1), (1,), device=device)
            # print("sampled_index: ", sampled_index)
            topk_indices = topk_indices.squeeze()
            # print("topk_indices sqqq: ", topk_indices)
            predicted_token_id = topk_indices.gather(-1, sampled_index).squeeze()
            # print("predicted_token_id: ", predicted_token_id)
            predicted_token_id = topk_indices.gather(-1, sampled_index).unsqueeze(0)
            # print("predicted_token_id: ", predicted_token_id)
            # print("input_tensor: ", input_tensor)
            
            
            
            
                # sampled_index = torch.multinomial(probs, 1)

            
                # predicted_token_id = topk_indices.gather(-1, sampled_index).squeeze()
                
                # predicted_token_id = torch.argmax(last_token_logits, dim=-1).unsqueeze(0)  # Choose the token with the highest probability
            # print("input_tensor:", input_tensor[0, -1].item())
            # print("predicted_token_id:", predicted_token_id.item())
            # Check if the last element of input_tensor matches predicted_token_id
            if input_tensor[0, -1].item() == predicted_token_id.item():
                continue  # Skip this iteration
            
            # Append the predicted token ID to the input (which is fed into the model in the next iteration)
            input_tensor = torch.cat((input_tensor, predicted_token_id), dim=1)
            # print("input_tensor: ", input_tensor)
            # break
            # input_tensor = torch.cat((input_tensor, predicted_token_id.unsqueeze(0)), dim=1)
            
            # Check if the end-of-sequence token was generated
            if predicted_token_id.item() == eos_token_id:
                break
    
    model.train()
    # Decode the generated sequence back to text
    generated_sequence = []
    print(input_tensor.squeeze().tolist())
    for token_id in input_tensor.squeeze().tolist():
        if token_id <= 10000:
            generated_sequence.append(sp.IdToPiece(token_id))
    
    generated_text = ''.join(generated_sequence).replace('▁', ' ').strip()  # Assuming '▁' is the SentencePiece underline character
    print(generated_text)
validation(model)


[56, 60, 8, 37]
tensor([[True, True, True, True]], device='mps:0')
tensor([[[[False,  True,  True,  True],
          [False, False,  True,  True],
          [False, False, False,  True],
          [False, False, False, False]]]], device='mps:0')
tensor([[True, True, True, True, True]], device='mps:0')
tensor([[[[False,  True,  True,  True,  True],
          [False, False,  True,  True,  True],
          [False, False, False,  True,  True],
          [False, False, False, False,  True],
          [False, False, False, False, False]]]], device='mps:0')
tensor([[True, True, True, True, True, True]], device='mps:0')
tensor([[[[False,  True,  True,  True,  True,  True],
          [False, False,  True,  True,  True,  True],
          [False, False, False,  True,  True,  True],
          [False, False, False, False,  True,  True],
          [False, False, False, False, False,  True],
          [False, False, False, False, False, False]]]], device='mps:0')
tensor([[True, True, True, True, True

In [16]:
# Assuming you have defined a model instance called `model`
# Setting a custom starting learning rate for the Adam optimizer
learning_rate = 1e-3  # Example: A smaller learning rate
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.CrossEntropyLoss()  # Adjust based on your specific task

EPOCHS = 10  # Number of epochs to train for

model.train()  # Set the model to training mode
for epoch in range(EPOCHS):
    for batch_idx, (stories, targets) in enumerate(train_loader):
        stories = stories.to(device)
        targets = targets.to(device)
        
        mask = (targets != 13000).float().to(device)
        
        optimizer.zero_grad()  # Zero the gradients
        output = model(stories)  # Forward pass: compute the output
    
        loss = masked_cross_entropy(output, targets, mask)

        loss.backward()  # Backward pass: compute gradient of the loss with respect to model parameters
        optimizer.step()  # Perform a single optimization step (parameter update)
        # print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()}')
        break
        if batch_idx % 100 == 0:  # Print loss every 100 batches
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()}')
            validation(model)
        
    break
    # Validation step
    # val_loss = validate(model, validation_loader, criterion)
    # print(f'Epoch {epoch}, Training Loss: {loss.item()}, Validation Loss: {val_loss}')
    # Validation loop can go here
    # Remember to set model.eval() and torch.no_grad() during validation

tensor([[[[False,  True,  True,  ...,  True,  True,  True],
          [False, False,  True,  ...,  True,  True,  True],
          [False, False, False,  ...,  True,  True,  True],
          ...,
          [False, False, False,  ..., False,  True,  True],
          [False, False, False,  ..., False, False,  True],
          [False, False, False,  ..., False, False, False]]]], device='mps:0')


RuntimeError: required rank 4 tensor to use channels_last format