<a href="https://colab.research.google.com/github/kevinjmann/transformer_summarizer/blob/main/TransformerSummarizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Create a transformer-based summarizer.

The dataset is Multi-News (https://github.com/Alex-Fabbri/Multi-News), which is available through TensorFlow Datasets. The nightly build of that package is installed below.

In [None]:
!pip install -q tfds-nightly

In [None]:
import tensorflow_datasets as tfds

In [None]:
builder = tfds.builder('multi_news')

In [None]:
builder.download_and_prepare()

In [None]:
!ls -al /root/tensorflow_datasets/multi_news/1.0.0/

total 686796
drwxr-xr-x 2 root root     4096 Jan 31 02:33 .
drwxr-xr-x 3 root root     4096 Jan 31 02:33 ..
-rw-r--r-- 1 root root     3113 Jan 31 02:33 dataset_info.json
-rw-r--r-- 1 root root      464 Jan 31 02:33 features.json
-rw-r--r-- 1 root root 70697817 Jan 31 02:33 multi_news-test.tfrecord-00000-of-00001
-rw-r--r-- 1 root root 70464259 Jan 31 02:33 multi_news-train.tfrecord-00000-of-00008
-rw-r--r-- 1 root root 69497487 Jan 31 02:33 multi_news-train.tfrecord-00001-of-00008
-rw-r--r-- 1 root root 69628388 Jan 31 02:33 multi_news-train.tfrecord-00002-of-00008
-rw-r--r-- 1 root root 71909427 Jan 31 02:33 multi_news-train.tfrecord-00003-of-00008
-rw-r--r-- 1 root root 70475688 Jan 31 02:33 multi_news-train.tfrecord-00004-of-00008
-rw-r--r-- 1 root root 69650383 Jan 31 02:33 multi_news-train.tfrecord-00005-of-00008
-rw-r--r-- 1 root root 68935097 Jan 31 02:33 multi_news-train.tfrecord-00006-of-00008
-rw-r--r-- 1 root root 73055971 Jan 31 02:33 multi_news-train.tfrecord-00007-of-000

Load all the data in the dataset

In [None]:
train_ds = tfds.as_numpy(tfds.load('multi_news', split='train', as_supervised=True, batch_size=-1))

In [None]:
val_ds = tfds.as_numpy(tfds.load('multi_news', split='validation', as_supervised=True, batch_size=-1))

In [None]:
documents, summaries = train_ds
val_docs, val_summaries = val_ds

In [None]:
# sample document entry
documents[0].decode()

'Flag Flap Underscores Trump\'s Strained Relationship With McCain \n  \n Enlarge this image toggle caption Mandel Ngan/AFP/Getty Images Mandel Ngan/AFP/Getty Images \n  \n Updated at 9:37 p.m. ET \n  \n The beginning of the national memorial for Sen. John McCain, R-Ariz., has been marred by a fight over a sign of public respect, as President Trump initially avoided issuing a proclamation to lower flags to half-staff at all federal properties in McCain\'s honor. \n  \n Flags were lowered at government buildings across Washington and across the country Saturday evening after McCain died, as is standard practice for a sitting member of Congress. \n  \n But on Monday morning the flag atop the White House was back at full-staff, causing some to ask whether Trump\'s strained relationship with McCain had played into the decision to not keep it lowered. The lack of a proclamation was viewed by some as a disrespectful act reflecting the president\'s dislike for McCain, which Trump continued to 

In [None]:
vocab = set()

In [None]:
nchar = 1  # nchar corresponds to char n-grams that I'll use as labels

In [None]:
from tqdm import tqdm

Collect the characters that appear in the dataset (nchar = 1, so the n-grams are single characters) and record how many documents and summaries each one appears in.

In [None]:
freq_dict = {}

In [None]:
print('documents')
for document in tqdm(documents):
  document = document.decode()[:-5]
  # for bi in set([document[i:i+nchar] for i in range(0, len(document), nchar)]):
  for bi in set(document):
    if bi not in freq_dict:
      freq_dict[bi] = 0
    freq_dict[bi] += 1

print('summaries')
for document in tqdm(summaries):
  document = document.decode()
  # for bi in set([document[i:i+nchar] for i in range(0, len(document), nchar)]):
  for bi in set(document):
    if bi not in freq_dict:
      freq_dict[bi] = 0
    freq_dict[bi] += 1

documents


100%|██████████| 44972/44972 [00:05<00:00, 7813.35it/s]


summaries


100%|██████████| 44972/44972 [00:01<00:00, 39297.59it/s]


The data is messy, so I want to remove rare characters (those that appear in fewer than 4 texts), which correspond to non-English text, links, emojis, etc.

In [None]:
sorted_freq_dict = dict(sorted(freq_dict.items(), key=lambda item: item[1]))

In [None]:
cleaned_freq_dict = {}

In [None]:
len(sorted_freq_dict)

2459

In [None]:
for key, freq in sorted_freq_dict.items():
  if freq <4:
    continue
  cleaned_freq_dict[key] = freq

In [None]:
len(cleaned_freq_dict)

705
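
The counting and filtering steps above amount to a document-frequency threshold. A minimal sketch of the same idea using `collections.Counter`, on a made-up three-text corpus; the function names and the `min_texts` parameter are hypothetical and stand in for the threshold of 4 used in the notebook:

```python
from collections import Counter

def char_document_frequency(texts):
    """Count, for each character, how many texts contain it at least once."""
    counts = Counter()
    for text in texts:
        counts.update(set(text))  # set() so a character counts once per text
    return counts

def drop_rare(counts, min_texts):
    """Keep only characters that appear in at least `min_texts` texts."""
    return {ch: n for ch, n in counts.items() if n >= min_texts}

# toy corpus (illustrative only, not taken from Multi-News)
toy_texts = ["abc", "abd", "a🎃"]
toy_counts = char_document_frequency(toy_texts)
toy_kept = drop_rare(toy_counts, min_texts=2)
print(sorted(toy_kept.items()))
```

Counting each character once per text means a single long document full of unusual symbols cannot push those symbols over the threshold on its own.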

In [None]:
import random

In [None]:
keys = list(cleaned_freq_dict.keys())
random.shuffle(keys)

In [None]:
keys[:100]

['了',
 '🎃',
 '‑',
 '\x94',
 'ದ',
 'R',
 '👇',
 'ƒ',
 'ು',
 'ಸ',
 'Î',
 '‘',
 '스',
 '！',
 '̈',
 'ʻ',
 'B',
 '😘',
 'ಪ',
 'Q',
 'η',
 'آ',
 '✊',
 '?',
 'Ş',
 'ھ',
 'P',
 'い',
 '「',
 'Ο',
 '👍',
 '\u202c',
 'シ',
 'я',
 '‹',
 '🔥',
 '（',
 '出',
 'φ',
 '年',
 '@',
 'ہ',
 'מ',
 'c',
 '\x91',
 'У',
 '🇸',
 '}',
 'ย',
 'ख',
 'ಶ',
 'â',
 '😢',
 '💗',
 'Г',
 'ಹ',
 '±',
 'Ś',
 'χ',
 'К',
 '기',
 'о',
 'া',
 '¨',
 '—',
 'ة',
 ')',
 'З',
 'Р',
 '😂',
 '×',
 '✌',
 '보',
 '☀',
 'は',
 'ी',
 '🙂',
 '\x95',
 'ー',
 'ま',
 'к',
 'ː',
 '≥',
 'ő',
 'י',
 '🚨',
 'ÿ',
 '̃',
 'く',
 'у',
 'ジ',
 '\ue60e',
 'س',
 'α',
 'ś',
 '리',
 '😁',
 '◀',
 'ಖ',
 'ನ']

In [None]:
vocab = ['<null>', '<start>', '<end>'] + list(keys)  # add start and end labels that I can add to the docs

In [None]:
# dicts to convert back and forth from label indices
stoi = {s:i for i, s in enumerate(vocab)}
itos = {i:s for i, s in enumerate(vocab)}

In [None]:
# functions to encode and decode strings into lists of indices (characters missing from the vocab are dropped)
# encode = lambda x: list(filter(lambda y: y is not None, [stoi.get(s) for s in [x[i:i+nchar] for i in range(0, len(x), nchar)]]))
encode = lambda x: list(filter(lambda y: y is not None, [stoi.get(s) for s in x]))
decode = lambda x: ''.join([itos[i] for i in x])

In [None]:
print(encode("this is only a test this is only a test"))
print(decode(encode("this is only a test this is only a test")))


[137, 506, 641, 548, 400, 641, 548, 400, 311, 615, 248, 308, 400, 383, 400, 137, 669, 548, 137, 400, 137, 506, 641, 548, 400, 641, 548, 400, 311, 615, 248, 308, 400, 383, 400, 137, 669, 548, 137]
this is only a test this is only a test
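
Because `encode` skips anything that is not in `stoi`, the round trip is only lossless for text made of known characters. A small sketch on a toy vocabulary (all names prefixed with `toy_` are made up for illustration):

```python
# toy vocabulary (illustrative only); indices 0-2 are the special labels
toy_vocab = ['<null>', '<start>', '<end>'] + list("abc ")
toy_stoi = {s: i for i, s in enumerate(toy_vocab)}
toy_itos = {i: s for i, s in enumerate(toy_vocab)}

def toy_encode(text):
    # characters outside the vocabulary are skipped, as in encode above
    return [toy_stoi[ch] for ch in text if ch in toy_stoi]

def toy_decode(indices):
    return ''.join(toy_itos[i] for i in indices)

ids = toy_encode("a cab!")
round_trip = toy_decode(ids)
print(ids, repr(round_trip))
```

Here the `!` is not in the toy vocabulary, so it disappears from the decoded string.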


In [None]:
summaries[0]

b'\xe2\x80\x93 The White House flag had a more tumultuous start to the week than typical, having been lowered to half-staff on Saturday in the wake of Sen. John McCain\'s death, then raised back up on Monday in keeping with the US flag code but to an outcry from the media and others who felt a flag proclamation from President Trump was in order. After lowering the flag once more, that proclamation came Monday afternoon; Reuters describes it as being issued in a "delayed" manner, noting that presidents typically take their cue from Congress when a high-profile lawmaker dies. NPR has President Trump\'s statement: "Despite our differences on policy and politics, I respect Senator John McCain\'s service to our country and, in his honor, have signed a proclamation to fly the flag of the United States at half-staff until the day of his interment," which is Saturday. NPR details the groups who implored the president to make the move, including the American Legion and the veterans group AMVETS

In [None]:
# trim some of the trash in the dataset

desired_docs = []
desired_summaries = []
for i, doc in enumerate(documents):
  if len(doc) > 6 * 10000:  # roughly corresponds to 10000 word articles
    continue
  else:
    desired_docs.append([1] + encode(doc.decode()[:-5]) + [2])
    desired_summaries.append([1] + encode(summaries[i].decode()) + [2])

val_docs, val_summaries = val_ds
desired_val_docs = []
desired_val_summaries = []
for i, doc in enumerate(val_docs):
  if len(doc) > 6*10000:
    continue
  else:
    desired_val_docs.append([1] + encode(doc.decode()[:-5]) + [2])
    desired_val_summaries.append([1] + encode(val_summaries[i].decode()) + [2])

In [None]:
longest_doc = 0
for doc in tqdm(desired_docs):
  doclen = len(doc)
  if doclen > longest_doc:
    longest_doc = doclen

longest_summary = 0
for doc in tqdm(desired_summaries):
  doclen = len(doc)
  if doclen > longest_summary:
    longest_summary = doclen
print()
print(longest_doc, longest_summary)

100%|██████████| 44649/44649 [00:00<00:00, 2427966.44it/s]
100%|██████████| 44649/44649 [00:00<00:00, 2171691.57it/s]


59692 5912





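Modified version of the decoder-only transformer from Andrej Karpathy's video (https://www.youtube.com/watch?v=kCc8FmEb1nY), extended with an encoder and a second attention step over the encoder output in each decoder block.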

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
## Hyperparams
batch_size = 4
enc_length = 2000 #longest_doc / 2
dec_length = 2000
n_head = 4
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
head_size = 4
n_layer = 3
vocab_size = len(vocab)
num_epochs = 1
##
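
With `enc_length = 2000` and a longest document of 59692 labels, many documents will be cut off before they reach the encoder. A small sketch for checking how many sequences exceed a given length; `fraction_truncated` is a hypothetical helper and the sequences are toy data:

```python
def fraction_truncated(sequences, max_len):
    """Share of sequences that are longer than max_len and would be cut."""
    if not sequences:
        return 0.0
    return sum(len(s) > max_len for s in sequences) / len(sequences)

# toy sequences (illustrative only); in the notebook this would be called as
# fraction_truncated(desired_docs, enc_length)
toy_seqs = [[1] * 5, [1] * 10, [1] * 20, [1] * 40]
toy_fraction = fraction_truncated(toy_seqs, max_len=10)
print(toy_fraction)
```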

In [None]:
class Head(nn.Module):
    """ one head of self attention"""

    def __init__(self, head_size, is_masked=True) -> None:
        super().__init__()
        self.is_masked = is_masked
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(dec_length, dec_length)))
        if not is_masked:
          # self.enc_to_dec = nn.Linear(enc_length, dec_length)
          pass

    def forward(self, q_input, k_input, v_input):
        # in the encoder you only take in the input sequence
        # in the decoder you take in encoder output as input to key, and value, but query comes from previous
        B, T, C = q_input.shape
        k = self.key(k_input) # B, T, head_size
        q = self.query(q_input) # B, T, head_size

        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # B, T, T; scaled by head_size rather than n_embd
        # mask using a portion of the tril matrix because we may not be looking at the whole blocksize yet
        if self.is_masked:
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # B, T, T
        else:
            # wei = self.enc_to_dec(wei)
            pass
        wei = F.softmax(wei, dim=-1) # B, T, T
        v = self.value(v_input)
        out = wei @ v
        return out


class MultiHeadAttention(nn.Module):
    """ multiple heads of attention in parallel"""

    def __init__(self, num_heads, head_size, is_masked=True):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, is_masked=is_masked) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, q_input, k_input, v_input):
        # concatenate the outputs over the last dimension, the channel dimension
        out = torch.cat([h(q_input, k_input, v_input) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out


class FeedForward(nn.Module):

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd), # projection going back into residual pathway
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    def __init__(self, n_embd, n_head, is_masked=True) -> None:
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, is_masked=is_masked)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):

        x = x + self.sa(self.ln1(x), self.ln1(x), self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class Encoder(nn.Module):
    def __init__(self, n_embd, n_head, block_size) -> None:
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, is_masked=False) for _ in range(n_layer)])

    def forward(self, input):
        B, T = input.shape
        tok_emb = self.token_embedding_table(input)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        input = tok_emb + pos_emb
        input = self.blocks(input)
        return input


class DecoderBlock(nn.Module):
    """Decoder block includes a step that includes the unmasked output of the encoder"""
    def __init__(self, n_embd, n_head, is_masked, pass_along_input) -> None:
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, is_masked=is_masked)
        self.sa2 = MultiHeadAttention(n_head, head_size, is_masked=False)
        self.pass_along_input = pass_along_input
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ln3 = nn.LayerNorm(n_embd)
        self.ln4 = nn.LayerNorm(n_embd)

    def forward(self, input):
        x, enc_out = input
        x = x + self.sa(self.ln1(x), self.ln1(x), self.ln1(x))
        x = x + self.sa2(self.ln2(x), enc_out, enc_out)  # cross-attention: query from the decoder, key and value from the encoder output
        x = x + self.ffwd(self.ln3(x))
        if self.pass_along_input:
          return x, enc_out
        return x


class Decoder(nn.Module):
    def __init__(self, n_embd, n_head, block_size) -> None:
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        blocks = []
        for i in range(n_layer):
          pass_along_input = i < (n_layer - 1)
          blocks.append(DecoderBlock(n_embd, n_head, is_masked=True, pass_along_input=pass_along_input))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input, enc_out):
        B, T = input.shape
        tok_emb = self.token_embedding_table(input)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        input = tok_emb + pos_emb
        x = self.blocks((input, enc_out))  # keep the output of the decoder blocks instead of discarding it
        return x


class Transformer(nn.Module):

    def __init__(self):
        super().__init__()
        ## encoder start
        # self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # self.lm_head = nn.Linear(n_embd, vocab_size)
        # # self.sa_heads = MultiHeadAttention(4, n_embd // 4)
        # # self.ffwd = FeedForward(n_embd)
        # self.blocks = nn.Sequential(
        #     Block(n_embd, 4),
        #     Block(n_embd, 4),
        #     Block(n_embd, 4),
        # )
        self.encoder = Encoder(n_embd, n_head, enc_length)
        self.decoder = Decoder(n_embd, n_head, dec_length)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, input_seq, idx, targets=None):
        # B, T = idx.shape
        # tok_emb = self.token_embedding_table(x)  # B, T, C
        # pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # T, C
        # x = tok_emb + pos_emb # B, T, C
        # # x = self.sa_head(x)  # B, T, C apply one head of self attention
        # # x = self.ffwd(x)
        # x = self.blocks(x)
        # logits = self.lm_head(x)  # B, T, vocab_size
        enc_out = self.encoder(input_seq)
        x = self.decoder(idx, enc_out)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    # def generate(self, input, idx, max_new_tokens):
    #     #idx is (B, T) array of indices in the current context
    #     for t in range(max_new_tokens):
    #         idx_cond = idx[:, -dec_length:]  # only the last block_size chars at most
    #         # get the predictions
    #         logits, loss = self(input, idx_cond) # logits has different dimensions during training and generation in order to calculate loss
    #         logits = logits[:, t, :] # becomes (B, C) for the last timestep only for each chunk in batch
    #         # convert logits to probs
    #         probs = F.softmax(logits, dim=1)
    #         # sample from the distribution
    #         idx_next = torch.multinomial(probs, num_samples=1)
    #         idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    #     return idx
    def generate(self, input, idx, max_new_tokens):
      #idx is (B, T) array of indices in the current context
      for t in range(max_new_tokens - 1):
          idx_cond = idx[:, -dec_length:]  # only the last block_size chars at most
          # get the predictions
          logits, loss = self(input, idx_cond) # logits has different dimensions during training and generation in order to calculate loss
          logits = logits[:, t, :] # becomes (B, C) for the last timestep only for each chunk in batch
          # convert logits to probs
          probs = F.softmax(logits, dim=1)
          # sample from the distribution
          idx_next = torch.multinomial(probs, num_samples=1)
          idx[0, t + 1] = idx_next
          # idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
      return idx
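
The comment in `Head.forward` describes cross-attention: the query comes from the decoder, while the key and value come from the encoder output. A minimal single-head sketch of that data flow, with made-up sizes and freshly initialized projection layers (none of these names or sizes are the notebook's):

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)
toy_embd, toy_head = 8, 4          # illustrative sizes
B, T_dec, T_enc = 2, 3, 5          # decoder and encoder lengths may differ

query = nn.Linear(toy_embd, toy_head, bias=False)
key = nn.Linear(toy_embd, toy_head, bias=False)
value = nn.Linear(toy_embd, toy_head, bias=False)

dec_x = torch.randn(B, T_dec, toy_embd)    # decoder hidden states
enc_out = torch.randn(B, T_enc, toy_embd)  # encoder output

q = query(dec_x)                                  # (B, T_dec, head)
k = key(enc_out)                                  # (B, T_enc, head)
v = value(enc_out)                                # (B, T_enc, head)
wei = q @ k.transpose(-2, -1) * toy_head ** -0.5  # (B, T_dec, T_enc)
wei = F.softmax(wei, dim=-1)                      # no causal mask over the encoder
cross_out = wei @ v                               # (B, T_dec, head)
print(cross_out.shape)
```

Each decoder position ends up with a weighted average of encoder values, so the output keeps the decoder's sequence length.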


In [None]:
!nvidia-smi

Tue Jan 31 08:56:57 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   65C    P0    31W /  70W |    312MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
model = Transformer().to(device)


In [None]:
!pip install torchsummary

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from torchsummary import summary

In [None]:
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

0.833548 M parameters


In [None]:
train_ds = list(zip(desired_docs, desired_summaries))
random.shuffle(train_ds)
val_ds = list(zip(desired_val_docs, desired_val_summaries))
random.shuffle(val_ds)

In [None]:
def get_batch(i, split):
  data = train_ds if split == 'train' else val_ds
  xb, yb, targets = [], [], []
  for doc, summ in data[i*batch_size : (i+1)*batch_size]:
    # pad with <null> (0) and truncate to fixed lengths; new lists are built so the dataset itself is not modified
    xb.append((doc + [0]*enc_length)[:enc_length])
    yb.append((summ + [0]*dec_length)[:dec_length])
    # targets are the decoder input shifted left by one position
    targets.append(yb[-1][1:] + [0])
  return torch.tensor(xb).to(device), torch.tensor(yb).to(device), torch.tensor(targets).to(device)
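
The padding and shifting done in `get_batch` is easier to see on a toy example. A sketch with two hypothetical helpers, `pad_to` and `decoder_pair`, and made-up label ids:

```python
def pad_to(seq, length, pad=0):
    """Return a new list of exactly `length` items, truncated or padded."""
    return (seq + [pad] * length)[:length]

def decoder_pair(summary_ids, length):
    """Decoder input and next-token targets for one summary."""
    dec_in = pad_to(summary_ids, length)
    target = dec_in[1:] + [0]   # target[t] is the token that follows dec_in[t]
    return dec_in, target

# toy summary: <start>=1, three characters, <end>=2 (illustrative ids)
toy_in, toy_target = decoder_pair([1, 7, 8, 9, 2], length=7)
print(toy_in)
print(toy_target)
```

At every position the model sees the summary so far and is asked for the next label, which is why the targets are the inputs moved one step to the left.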

In [None]:
a, b, c = get_batch(0, 'train')

In [None]:
print(a.shape, b.shape, c.shape)

torch.Size([4, 2000]) torch.Size([4, 2000]) torch.Size([4, 2000])


In [None]:
num_batches = len(train_ds) // batch_size

In [None]:
num_val_batches = len(val_ds) // batch_size

In [None]:
print(model)

Transformer(
  (encoder): Encoder(
    (token_embedding_table): Embedding(708, 64)
    (position_embedding_table): Embedding(2000, 64)
    (lm_head): Linear(in_features=64, out_features=708, bias=True)
    (blocks): Sequential(
      (0): Block(
        (sa): MultiHeadAttention(
          (heads): ModuleList(
            (0): Head(
              (key): Linear(in_features=64, out_features=16, bias=False)
              (query): Linear(in_features=64, out_features=16, bias=False)
              (value): Linear(in_features=64, out_features=16, bias=False)
            )
            (1): Head(
              (key): Linear(in_features=64, out_features=16, bias=False)
              (query): Linear(in_features=64, out_features=16, bias=False)
              (value): Linear(in_features=64, out_features=16, bias=False)
            )
            (2): Head(
              (key): Linear(in_features=64, out_features=16, bias=False)
              (query): Linear(in_features=64, out_features=16, bias=False

In [None]:
logits, loss = model(a, b, c)

In [None]:
import numpy as np

In [None]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        max_batches = num_batches if split == 'train' else num_val_batches
        for i, k in enumerate(np.random.randint(max_batches, size=eval_iters)):
            X, Y, targets = get_batch(k, split)
            logits, loss = model(X, Y, targets=targets)
            losses[i] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
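
The loss computed in `Transformer.forward` averages over every position, including the `<null>` padding at index 0. One possible variation, which is not what the notebook does, is to skip padded targets with the `ignore_index` argument of `F.cross_entropy`. A sketch on made-up logits:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
toy_logits = torch.randn(4, 5)             # 4 positions, 5 classes (illustrative)
toy_targets = torch.tensor([3, 1, 0, 0])   # last two positions are padding

loss_all = F.cross_entropy(toy_logits, toy_targets)
loss_no_pad = F.cross_entropy(toy_logits, toy_targets, ignore_index=0)

# ignoring index 0 is the same as averaging over the two real positions only
loss_real_only = F.cross_entropy(toy_logits[:2], toy_targets[:2])
print(loss_all.item(), loss_no_pad.item())
```

With long runs of padding in each summary, the two numbers can differ a lot, so reported losses are only comparable when they use the same convention.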


In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
def generate(model, input, idx, max_new_tokens):
    #idx is (B, T) array of indices in the current context
    for t in range(max_new_tokens - 1):
        idx_cond = idx[:, -dec_length:]  # only the last block_size chars at most
        # get the predictions
        logits, loss = model(input, idx_cond) # logits has different dimensions during training and generation in order to calculate loss
        logits = logits[:, t, :] # becomes (B, C) for the last timestep only for each chunk in batch
        # convert logits to probs
        probs = F.softmax(logits, dim=1)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)
        idx[0, t + 1] = idx_next
        # idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx
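
The `generate` function above writes into a fixed-size buffer and always samples `max_new_tokens - 1` labels. A variation that grows the sequence and stops once `<end>` is drawn could look like the sketch below. `sample_until_end` and `next_token_logits` are hypothetical names, and the stand-in model is a toy that always predicts `<end>`:

```python
import torch
from torch.nn import functional as F

def sample_until_end(next_token_logits, start_id, end_id, max_new_tokens):
    """Sample one token at a time, stopping early once end_id is drawn."""
    idx = torch.tensor([[start_id]], dtype=torch.long)      # (1, 1)
    for _ in range(max_new_tokens):
        logits = next_token_logits(idx)                     # (1, vocab)
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)  # (1, 1)
        idx = torch.cat((idx, idx_next), dim=1)             # (1, T+1)
        if idx_next.item() == end_id:
            break
    return idx

# stand-in for a model: puts all probability mass on <end> (index 2)
def toy_logits_fn(idx):
    logits = torch.full((1, 5), -1e9)
    logits[0, 2] = 0.0
    return logits

toy_out = sample_until_end(toy_logits_fn, start_id=1, end_id=2, max_new_tokens=10)
print(toy_out.tolist())
```

In the notebook, `next_token_logits` would wrap a call to the model and pick out the logits for the last filled position.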

In [None]:
def get_qualitative_result():
  test_text = '''Neanderthals may have lived in larger groups than previously believed, hunting massive elephants that were up to three times bigger than those of today, according to a new study.

The researchers reached their conclusions, published in the journal Science Advances on Wednesday, based on examinations of the 125,000-year-old skeletal remains of straight-tusked elephants found near Halle in central Germany.

The bones of about 70 elephants from the Pleistocene era were discovered in the 1980s in a huge coal quarry that has since been converted into an artificial lake.

Elephants of the time were much larger than the woolly mammoth and three times the size of the present-day Asian elephant: an adult male could weigh up to 13 tonnes.

“Hunting these giant animals and completely butchering them was part of Neanderthal subsistence activities at this location,” Wil Roebroeks, a co-author of the study, told AFP.

“This constitutes the first clearcut evidence of elephant-hunting in human evolution,” said Roebroeks, a professor of archeology at Leiden University in the Netherlands.

The study suggests that the Neanderthals who lived in the area for 2,000 to 4,000 years were less mobile and formed social units “substantially larger than commonly envisaged”.

“Neanderthals were not simple slaves of nature, original hippies living off the land,” Roebroeks said.

“They were actually shaping their environment, by fire … and also by having a big impact on the biggest animals that were around in the world at that time.”

The researchers determined the elephants had been hunted – and not just scavenged – because of the age and sex profile of the remains found in the quarry.

Most of them were males and there were few young or old ones.

“It’s a typical selection made by hunters who went for the biggest prey,” Roebroeks said.
'''
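  # wrap in <start>/<end> like the training documents, then truncate and pad to enc_length
  encoded = ([1] + encode(test_text) + [2])[:enc_length]
  encoded += [0] * (enc_length - len(encoded))
  idx = torch.zeros((1, dec_length), dtype=torch.long)
  idx[0, 0] = 1  # start the summary with the <start> label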
  encoded = torch.tensor(encoded).unsqueeze(0).to(device)
  idx = idx.to(device)
  result = generate(model, encoded, idx, max_new_tokens=2000)
  print(decode(result[0].tolist()))

In [None]:
for epoch in range(5):
  print(f"start epoch {epoch}")
  for batch_idx in range(num_batches):
    if batch_idx % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step {batch_idx}: train loss {losses['train']:.4f}, val_loss {losses['val']:.4f} ")

    xb, yb, targets = get_batch(batch_idx, 'train')

    logits, loss = model(xb, yb, targets=targets)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
  get_qualitative_result()



start epoch 0
Step 0: train loss 6.9674, val_loss 6.9697 
Step 500: train loss 1.7730, val_loss 1.7173 
Step 1000: train loss 1.6869, val_loss 1.6699 
Step 1500: train loss 1.6644, val_loss 1.6867 
Step 2000: train loss 1.6482, val_loss 1.6111 
Step 2500: train loss 1.6407, val_loss 1.6255 
Step 3000: train loss 1.6507, val_loss 1.6232 
Step 3500: train loss 1.6712, val_loss 1.6261 
Step 4000: train loss 1.6382, val_loss 1.6266 
Step 4500: train loss 1.6310, val_loss 1.5963 
Step 5000: train loss 1.6646, val_loss 1.5827 
Step 5500: train loss 1.6454, val_loss 1.6157 
Step 6000: train loss 1.6370, val_loss 1.6013 
Step 6500: train loss 1.6425, val_loss 1.6104 
Step 7000: train loss 1.6141, val_loss 1.6323 
Step 7500: train loss 1.6335, val_loss 1.6148 
Step 8000: train loss 1.6082, val_loss 1.6034 
Step 8500: train loss 1.6326, val_loss 1.6216 
Step 9000: train loss 1.6362, val_loss 1.6044 
Step 9500: train loss 1.6050, val_loss 1.6112 
Step 10000: train loss 1.6035, val_loss 1.6243 
St

In [None]:
get_qualitative_result()

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
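
Creating a new AdamW optimizer also resets its running moment estimates. If only the learning rate should change, one alternative is to update `param_groups` on the existing optimizer. A sketch with a throwaway linear layer standing in for the model:

```python
import torch
import torch.nn as nn

toy_model = nn.Linear(4, 2)   # stand-in for the Transformer
toy_opt = torch.optim.AdamW(toy_model.parameters(), lr=1e-3)

# lower the learning rate in place; the optimizer state is kept
for group in toy_opt.param_groups:
    group['lr'] = 3e-4

print(toy_opt.param_groups[0]['lr'])
```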

In [None]:
for epoch in range(5):
  print(f"start epoch {epoch}")
  for batch_idx in range(num_batches):
    if batch_idx % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step {batch_idx}: train loss {losses['train']:.4f}, val_loss {losses['val']:.4f} ")

    xb, yb, targets = get_batch(batch_idx, 'train')

    logits, loss = model(xb, yb, targets=targets)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
  get_qualitative_result()
