In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import requests
from transformers import AutoTokenizer

print(torch.backends.mps.is_available())  # Should print True
print(torch.backends.mps.is_built())      # Should print True

# device = torch.device('mps')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device', device)

False
False
device cuda


In [2]:
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [3]:
r = requests.get("https://s3.amazonaws.com/text-datasets/nietzsche.txt")
nietzsche_corpus = r.text

In [4]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize_text(text):
  encoded = tokenizer.encode(
      text,
      add_special_tokens=True,
      max_length=512,
      padding='max_length',
      truncation=True,
      return_tensors='pt'
  )

  return encoded.squeeze()

In [5]:
class NietzscheDataset(Dataset):
  def __init__(self, text, tokenizer, sequence_length=511):  # 511 + 1 special token = 512
    self.sequence_length = sequence_length
    self.tokenizer = tokenizer

    chunk_size = sequence_length * 100
    text_chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    all_tokens = []
    for chunk in text_chunks:
      tokens = tokenizer.encode(
          chunk,
          add_special_tokens=True,
          max_length=512,
          truncation=True,
          return_tensors='pt'
      )
      all_tokens.extend(tokens.squeeze().tolist())

    self.tokens = torch.tensor(all_tokens)
    self.num_sequences = len(self.tokens) - sequence_length - 1

  def __len__(self):
    return self.num_sequences

  def __getitem__(self, idx):
    sequence = self.tokens[idx:idx + self.sequence_length]
    target = self.tokens[idx + 1:idx + self.sequence_length + 1]
    return sequence, target


dataset = NietzscheDataset(nietzsche_corpus, tokenizer)

dataloader = DataLoader(
  dataset,
  batch_size=32,
  shuffle=True,
  num_workers=0
)

In [6]:
class NietzscheLSTM(nn.Module):
  def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers=3, num_heads=8, dropout=0.1):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.lstm = nn.LSTM(
      input_size=embedding_dim,
      hidden_size=hidden_size,
      num_layers=num_layers,
      dropout=0.2,
      batch_first=True
    )
    self.attention = nn.MultiheadAttention(
      embed_dim=hidden_size,
      num_heads=num_heads,
      dropout=dropout,
      batch_first=True
    )

    self.layer_norm = nn.LayerNorm(hidden_size)

    self.dropout = nn.Dropout(dropout)

    self.fc = nn.Linear(hidden_size, vocab_size)

  def forward(self, idx, hidden=None):
    embeddings = self.embedding(idx)
    output, hidden = self.lstm(embeddings, hidden)
    attn_output, attn_weights = self.attention(
      query=output,
      key=output,
      value=output
    )
    attn_output = self.layer_norm(output + attn_output)
    attn_output = self.dropout(attn_output)
    logits = self.fc(attn_output)
    return logits, hidden, attn_weights


model = NietzscheLSTM(vocab_size=tokenizer.vocab_size, embedding_dim=256, hidden_size=512).to(device)

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

num_epochs = 50
early_stop_threshold = 0.01
patience = 3
best_loss = float('inf')
patience_counter = 0

for epoch in range(num_epochs):
  model.train()
  total_loss = 0

  for batch_idx, (sequences, targets) in enumerate(dataloader):
    sequences = sequences.to(device)
    targets = targets.to(device)

    optimizer.zero_grad()

    logits, _, _ = model.forward(sequences)

    logits = logits.view(-1, tokenizer.vocab_size)
    targets = targets.view(-1)

    loss = criterion(logits, targets)
    loss.backward()

    optimizer.step()

    total_loss += loss.item()
    if batch_idx % 100 == 0:
      print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

  avg_loss = total_loss / len(dataloader)
  print(f'Epoch {epoch} completed. Average loss: {avg_loss:.4f}')

  if avg_loss < early_stop_threshold:
    print(f'Loss {avg_loss:.4f} below threshold {early_stop_threshold}. Stopping early.')
    break

  if avg_loss < best_loss:
    best_loss = avg_loss
    patience_counter = 0
  else:
    patience_counter += 1
    if patience_counter >= patience:
      print(f'No improvement for {patience} epochs. Stopping early.')
      break

Epoch: 0, Batch: 0, Loss: 10.4688
Epoch: 0, Batch: 100, Loss: 0.1071
Epoch 0 completed. Average loss: 1.3464
Epoch: 1, Batch: 0, Loss: 0.0409
Epoch: 1, Batch: 100, Loss: 0.0256
Epoch 1 completed. Average loss: 0.0275
Epoch: 2, Batch: 0, Loss: 0.0179
Epoch: 2, Batch: 100, Loss: 0.0171
Epoch 2 completed. Average loss: 0.0163
Epoch: 3, Batch: 0, Loss: 0.0125
Epoch: 3, Batch: 100, Loss: 0.0144
Epoch 3 completed. Average loss: 0.0122
Epoch: 4, Batch: 0, Loss: 0.0087
Epoch: 4, Batch: 100, Loss: 0.0104
Epoch 4 completed. Average loss: 0.0100
Epoch: 5, Batch: 0, Loss: 0.0085
Epoch: 5, Batch: 100, Loss: 0.0092
Epoch 5 completed. Average loss: 0.0088
Loss 0.0088 below threshold 0.01. Stopping early.


In [8]:
@torch.no_grad()
def generate(model, tokenizer, start_sequence, max_new_tokens, temperature=0.8):
  model.eval()
  device = model.embedding.weight.device

  x = start_sequence[start_sequence != tokenizer.pad_token_id]
  x = x.to(device)

  start_text = tokenizer.decode(x.tolist(), skip_special_tokens=True)
  print(start_text, end=' ', flush=True)  # Add space after start text

  current_text = ''

  for _ in range(max_new_tokens):
    logits, _, _ = model.forward(x.view(1, -1))
    logits = logits[0, -1, :] / temperature

    for special_id in [tokenizer.pad_token_id, tokenizer.sep_token_id, tokenizer.cls_token_id]:
      logits[special_id] = float('-inf')

    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)

    new_text = tokenizer.decode([next_token.item()], skip_special_tokens=True)

    if new_text.startswith('##'):
      current_text += new_text[2:]  # Remove ## and append to current word
    else:
      if current_text:  # If we have accumulated text, print it with a space
        print(current_text, end=' ', flush=True)
        current_text = ''
      current_text = new_text

    x = torch.cat([x, next_token])

    if len(x) > 512:
      x = x[-511:]

  if current_text:  # Print any remaining text
    print(current_text, end=' ', flush=True)
  print()

start_sequence = tokenize_text("Thus spoke Zarathustra")
generated_text = generate(
  model=model,
  tokenizer=tokenizer,
  start_sequence=start_sequence,
  max_new_tokens=2048,
  temperature=0.8
)

thus spoke zarathustra its little do - - a super - european music , which holds its own even in presence of the brown sunsets of the desert , whose soul is akin to the palm - tree , and can be at home and can roam with big , beautiful , lonely beasts of prey . . . i could imagine a music of which the rarest charm would be that it knew nothing more of good and evil ; only that here and there perhaps some sailor ' s home - sickness , some analysis " ( ' s taking the expression in its widest sense ) perhaps not be the exception , but the rule ? - - perhaps genius is by no means so rare : but rather the five hundred hands which it requires in order to tyrannize over the [ greek inserted here ] , " the right time " - - in order to take chance by the forelock ! 275 . he who does not wish to see the height of a man , looks all the more sharply at what is low in him , and in the foreground - - and thereby betrays himself . 276 . in all kinds of injury and loss the lower and coarser soul is bet

In [9]:
from google.colab import drive
drive.mount('/content/drive')

torch.save(model.state_dict(), '/content/drive/My Drive/nietzsche_lstm_bert_attention_weights.pth')

Mounted at /content/drive
