In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import requests
from transformers import AutoTokenizer

print(torch.backends.mps.is_available())  # Should print True
print(torch.backends.mps.is_built())      # Should print True

# device = torch.device('mps')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device', device)

False
False
device cuda


In [8]:
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [2]:
r = requests.get("https://s3.amazonaws.com/text-datasets/nietzsche.txt")
nietzsche_corpus = r.text

In [3]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize_text(text):
  encoded = tokenizer.encode(
      text,
      add_special_tokens=True,
      max_length=512,
      padding='max_length',
      truncation=True,
      return_tensors='pt'
  )

  return encoded.squeeze()

In [4]:
class NietzscheDataset(Dataset):
  def __init__(self, text, tokenizer, sequence_length=511):  # 511 + 1 special token = 512
    self.sequence_length = sequence_length
    self.tokenizer = tokenizer

    chunk_size = sequence_length * 100
    text_chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    all_tokens = []
    for chunk in text_chunks:
      tokens = tokenizer.encode(
          chunk,
          add_special_tokens=True,
          max_length=512,
          truncation=True,
          return_tensors='pt'
      )
      all_tokens.extend(tokens.squeeze().tolist())

    self.tokens = torch.tensor(all_tokens)
    self.num_sequences = len(self.tokens) - sequence_length - 1

  def __len__(self):
    return self.num_sequences

  def __getitem__(self, idx):
    sequence = self.tokens[idx:idx + self.sequence_length]
    target = self.tokens[idx + 1:idx + self.sequence_length + 1]
    return sequence, target


dataset = NietzscheDataset(nietzsche_corpus, tokenizer)

dataloader = DataLoader(
  dataset,
  batch_size=32,
  shuffle=True,
  num_workers=0
)

In [5]:
class NietzscheLSTM(nn.Module):
  def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers=3):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.lstm = nn.LSTM(
      input_size=embedding_dim,
      hidden_size=hidden_size,
      num_layers=num_layers,
      dropout=0.2,
      batch_first=True
    )
    self.fc = nn.Linear(hidden_size, vocab_size)

  def forward(self, idx, hidden=None):
    embeddings = self.embedding(idx)
    output, hidden = self.lstm(embeddings, hidden)
    logits = self.fc(output)
    return logits, hidden


model = NietzscheLSTM(vocab_size=tokenizer.vocab_size, embedding_dim=256, hidden_size=512).to(device)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

num_epochs = 50
early_stop_threshold = 0.1
patience = 3
best_loss = float('inf')
patience_counter = 0

for epoch in range(num_epochs):
  model.train()
  total_loss = 0

  for batch_idx, (sequences, targets) in enumerate(dataloader):
    sequences = sequences.to(device)
    targets = targets.to(device)

    optimizer.zero_grad()

    logits, _ = model.forward(sequences)

    logits = logits.view(-1, tokenizer.vocab_size)
    targets = targets.view(-1)

    loss = criterion(logits, targets)
    loss.backward()

    optimizer.step()

    total_loss += loss.item()
    if batch_idx % 100 == 0:
      print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

  avg_loss = total_loss / len(dataloader)
  print(f'Epoch {epoch} completed. Average loss: {avg_loss:.4f}')

  if avg_loss < early_stop_threshold:
    print(f'Loss {avg_loss:.4f} below threshold {early_stop_threshold}. Stopping early.')
    break

  if avg_loss < best_loss:
    best_loss = avg_loss
    patience_counter = 0
  else:
    patience_counter += 1
    if patience_counter >= patience:
      print(f'No improvement for {patience} epochs. Stopping early.')
      break

Epoch: 0, Batch: 0, Loss: 10.3277
Epoch: 0, Batch: 100, Loss: 6.0928
Epoch 0 completed. Average loss: 6.2412
Epoch: 1, Batch: 0, Loss: 6.1077
Epoch: 1, Batch: 100, Loss: 6.1506
Epoch 1 completed. Average loss: 6.1000
Epoch: 2, Batch: 0, Loss: 6.0903
Epoch: 2, Batch: 100, Loss: 6.1177
Epoch 2 completed. Average loss: 6.0272
Epoch: 3, Batch: 0, Loss: 5.7732
Epoch: 3, Batch: 100, Loss: 5.3699
Epoch 3 completed. Average loss: 5.4118
Epoch: 4, Batch: 0, Loss: 5.1484
Epoch: 4, Batch: 100, Loss: 4.8820
Epoch 4 completed. Average loss: 4.8492
Epoch: 5, Batch: 0, Loss: 4.5295
Epoch: 5, Batch: 100, Loss: 4.1693
Epoch 5 completed. Average loss: 4.2396
Epoch: 6, Batch: 0, Loss: 3.9329
Epoch: 6, Batch: 100, Loss: 3.5364
Epoch 6 completed. Average loss: 3.5850
Epoch: 7, Batch: 0, Loss: 3.2327
Epoch: 7, Batch: 100, Loss: 2.7578
Epoch 7 completed. Average loss: 2.8147
Epoch: 8, Batch: 0, Loss: 2.4283
Epoch: 8, Batch: 100, Loss: 2.1190
Epoch 8 completed. Average loss: 2.1467
Epoch: 9, Batch: 0, Loss: 1

In [10]:
@torch.no_grad()
def generate(model, tokenizer, start_sequence, max_new_tokens, temperature=0.8):
  model.eval()
  device = model.embedding.weight.device

  x = start_sequence[start_sequence != tokenizer.pad_token_id]
  x = x.to(device)

  start_text = tokenizer.decode(x.tolist(), skip_special_tokens=True)
  print(start_text, end=' ', flush=True)  # Add space after start text

  current_text = ''

  for _ in range(max_new_tokens):
    logits, _ = model(x.view(1, -1))
    logits = logits[0, -1, :] / temperature

    for special_id in [tokenizer.pad_token_id, tokenizer.sep_token_id, tokenizer.cls_token_id]:
      logits[special_id] = float('-inf')

    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)

    new_text = tokenizer.decode([next_token.item()], skip_special_tokens=True)

    if new_text.startswith('##'):
      current_text += new_text[2:]  # Remove ## and append to current word
    else:
      if current_text:  # If we have accumulated text, print it with a space
        print(current_text, end=' ', flush=True)
        current_text = ''
      current_text = new_text

    x = torch.cat([x, next_token])

    if len(x) > 512:
      x = x[-511:]

  if current_text:  # Print any remaining text
    print(current_text, end=' ', flush=True)
  print()

start_sequence = tokenize_text("Thus spoke Zarathustra")
generated_text = generate(
  model=model,
  tokenizer=tokenizer,
  start_sequence=start_sequence,
  max_new_tokens=2048,
  temperature=0.8
)

thus spoke zarathustra niation as the assertion of decidedlylyous , of justice . let us would be religious that each mental are disposed : as a manfied lights belief has please ,ul everythingen of expertness to apply to nature the same strict science of interpretation that the philologists have devised for all literature , and to apply it for the purpose of a simple , direct interpretation of the message , and at the same time , not bring out a double meaning . but , as in the case of books and literature , errors of exposition seek woman ? and bad ? 115 so genuine , only name - and blood and even , errors of surrenders to the other " the foreground - - and being left home - virtuous men and seizing upon happiness make one is delight how may have lates her kind of fatherlandism , and at the cause of an simple , but they have men ) has assumed delight by - - naturally to the standpoint of self preservation , therefore to the egoism of this consideration : " why should i injure myself to

In [11]:
from google.colab import drive
drive.mount('/content/drive')

torch.save(model.state_dict(), '/content/drive/My Drive/nietzsche_lstm_bert_weights.pth')

Mounted at /content/drive
