In [4]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import math
import pandas as pd
import random
import wandb

from torch import nn, optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, ConcatDataset

from datasets import load_dataset
from transformers import AutoTokenizer
from staticvectors import StaticVectors
from datetime import datetime
from tqdm import tqdm

In [5]:
# set appropriate device: prefer CUDA, then Apple-silicon MPS, otherwise fall back to CPU
# (unconditionally falling back to 'mps' fails later on machines that have neither accelerator)
if torch.cuda.is_available():
  device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
  device = torch.device('mps')
else:
  device = torch.device('cpu')


class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal positional encodings to a batch of embeddings.

  Even feature columns receive sin(pos * w_k) and odd feature columns receive
  cos(pos * w_k), with w_k = 10000^(-2k / embed_dim). The module has no
  learnable parameters.

  Args:
    embed_dim: size of the last (feature) dimension of the inputs.
  """

  def __init__(self, embed_dim):
    super().__init__()

    self.embed_dim = embed_dim

  def forward(self, x):
    """Return `x` plus the positional encoding.

    Args:
      x: tensor indexed as (batch, sequence, embed_dim).

    Returns:
      Tensor of the same shape as `x` with the encoding added to every
      sequence in the batch.
    """
    seq_len = x.shape[1]

    # build everything directly on the input's device (no host -> device copies per call)
    pos = torch.arange(0, seq_len, dtype=torch.float32, device=x.device)
    exponent = torch.arange(0, self.embed_dim, step=2, dtype=torch.float32, device=x.device) / self.embed_dim
    enc = torch.exp((-math.log(10000.0)) * exponent)

    # (seq_len, ceil(embed_dim / 2)) matrix of pos * frequency
    prod = torch.outer(pos, enc)

    pe = torch.zeros(1, seq_len, self.embed_dim, dtype=torch.float32, device=x.device)
    pe[0, :, 0::2] = torch.sin(prod)
    # for an odd embed_dim there is one fewer odd column than even columns,
    # so drop the last frequency on the cosine side
    pe[0, :, 1::2] = torch.cos(prod[:, : self.embed_dim // 2])

    # the leading dimension of size 1 broadcasts over the batch
    return x + pe


class LanguageTransformer(nn.Module):
  """Token-level language model.

  Pipeline: token embedding -> sinusoidal positional encoding -> stack of
  `nn.TransformerEncoderLayer`s -> two-layer MLP producing vocabulary logits.

  Args:
    vocab_size: number of output classes of the classifier (and number of
      embedding rows when `word_emb` is not given).
    embed_dim: model width; must equal `word_emb.shape[1]` when `word_emb`
      is given.
    num_layers: number of encoder layers.
    num_heads: number of attention heads per layer.
    word_emb: optional 2-D tensor of pretrained embeddings; when given, the
      embedding table is initialised from it and kept frozen.

  Raises:
    ValueError: if `word_emb` is given and its width differs from `embed_dim`.
  """

  def __init__(
    self,
    vocab_size,
    embed_dim,
    num_layers,
    num_heads,
    word_emb=None
  ):
    super().__init__()

    # learned (or given) vector embeddings
    if word_emb is not None:
      if word_emb.shape[-1] != embed_dim:
        raise ValueError(
          f"word_emb has width {word_emb.shape[-1]} but embed_dim is {embed_dim}"
        )

      # freeze=True keeps the whole pretrained table fixed during training
      # (the earlier `weight.data[:].requires_grad_(False)` acted on a temporary
      # view and had no effect beyond this default)
      self.token_embedding = nn.Embedding.from_pretrained(word_emb, freeze=True)
    else:
      self.token_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

    # positional encodings
    self.pos_enc = PositionalEncoding(embed_dim=embed_dim)

    # prepare single transformer layer with multiheaded attention
    transformer_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)

    # create transformer with multiple layers
    self.transformer = nn.TransformerEncoder(encoder_layer=transformer_layer, num_layers=num_layers)

    # vocab classifier
    self.classifier = nn.Sequential(
      nn.Linear(in_features=embed_dim, out_features=1024),
      nn.ReLU(),
      nn.Linear(in_features=1024, out_features=vocab_size)
    )

  def forward(self, seq, causal=True):
    """Compute vocabulary logits for every position of `seq`.

    Args:
      seq: integer token ids indexed as (batch, sequence).
      causal: when True (default) position i may only attend to positions
        <= i, which is required for next-token training where the labels are
        the inputs shifted by one. Pass False for unrestricted (bidirectional)
        attention.

    Returns:
      Tensor of logits with shape (batch, sequence, vocab_size).
    """
    seq_len = seq.shape[1]

    # embed sequence with positional encodings
    seq_embed = self.pos_enc(self.token_embedding(seq))

    # additive attention mask: -inf strictly above the diagonal blocks attention
    # to future positions; it is created on the input's device so the module
    # does not depend on a global `device`
    mask = None
    if causal:
      mask = torch.triu(
        torch.full(
          (seq_len, seq_len),
          float('-inf'),
          dtype=seq_embed.dtype,
          device=seq_embed.device,
        ),
        diagonal=1,
      )

    seq_out = self.transformer(src=seq_embed, mask=mask)

    # classify every position's output into the target vocabulary
    return self.classifier(seq_out)

In [6]:
# bert tokenizer: loads the pretrained "bert-base-cased" tokenizer through AutoTokenizer;
# the cells below use it for its vocabulary (get_vocab / vocab_size) and to encode/decode the poems
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

In [7]:
# word2vec embeddings (300 dim)
word2vec = StaticVectors("neuml/word2vec")

# get_vocab() maps token -> id, but its iteration order is not guaranteed to follow the ids.
# nn.Embedding looks rows up by token id, so order the tokens by id before embedding them;
# row i of the matrix then belongs to token id i.
# NOTE(review): assumes the ids are contiguous 0..len(vocab)-1 — confirm for tokenizers with added tokens.
vocab = tokenizer.get_vocab()
tokens_by_id = sorted(vocab, key=vocab.get)

word2vec_embeddings = torch.tensor(word2vec.embeddings(tokens_by_id), dtype=torch.float32).to(device)

In [37]:
# training hyperparameters
train_config = dict(
    bs=16,
    lr=1e-3,
    weight_decay=1e-5,
    max_epochs=10,
)

# model hyperparameters; the embedding width is taken from the pretrained word2vec matrix
model_config = dict(
    emb_dim=word2vec_embeddings.shape[1],
    num_layers=4,
    num_heads=4,
)

# number of entries in the tokenizer vocabulary
vocab_len = len(tokenizer.get_vocab())

In [38]:
# fetch the poetry corpus from the Hugging Face hub
poetry_dataset = load_dataset("merve/poetry")

# encode every poem; each one is truncated / padded to exactly 2048 token ids
dataset_tokens = tokenizer(
    list(poetry_dataset['train']['content']),
    add_special_tokens=True,
    truncation=True,
    padding='max_length',
    max_length=2048,
    return_tensors="pt",
)

# next-token setup: the labels are the input ids shifted left by one position
inputs = dataset_tokens.input_ids[:, :-1]
labels = dataset_tokens.input_ids[:, 1:]
tokenized_dataset = torch.utils.data.TensorDataset(inputs, labels)

# sanity check: decode the first input sequence back into text
print(tokenizer.decode(tokenized_dataset[0][0], skip_special_tokens=True))

# shuffled mini-batches; the last incomplete batch is dropped
train_loader = DataLoader(
    tokenized_dataset,
    batch_size=train_config['bs'],
    shuffle=True,
    drop_last=True,
    num_workers=4,
)

Repo card metadata block was not found. Setting CardData to empty.


Let the bird of loudest lay On the sole Arabian tree Herald sad and trumpet be, To whose sound chaste wings obey. But thou shrieking harbinger, Foul precurrer of the fiend, Augur of the fever ' s end, To this troop come thou not near. From this session interdict Every fowl of tyrant wing, Save the eagle, feather ' d king ; Keep the obsequy so strict. Let the priest in surplice white, That defunctive music can, Be the death - divining swan, Lest the requiem lack his right. And thou treble - dated crow, That thy sable gender mak ' st With the breath thou giv ' st and tak ' st, ' Mongst our mourners shalt thou go. Here the anthem doth commence : Love and constancy is dead ; Phoenix and the Turtle fled In a mutual flame from hence. So they lov ' d, as love in twain Had the essence but in one ; Two distincts, division none : Number there in love was slain. Hearts remote, yet not asunder ; Distance and no space was seen ' Twixt this Turtle and his queen : But in them it were a wonder. So bet

In [39]:
# create language model
# vocab_len (not tokenizer.vocab_size) sizes the classifier so that it agrees with the
# embedding matrix, the random ids sampled below and the shape assertion
model = LanguageTransformer(
    vocab_size=vocab_len,
    embed_dim=model_config['emb_dim'],
    num_layers=model_config['num_layers'],
    num_heads=model_config['num_heads'],
    word_emb=word2vec_embeddings
    ).to(device)

# quick dry run on random token ids
bs = train_config["bs"]
seq_len = random.randint(1, 100)  # start at 1: an empty sequence is not a meaningful smoke test

seq = torch.randint(0, vocab_len, (bs, seq_len), device=device)

# only the output shape matters here, so skip building the autograd graph
with torch.no_grad():
    out = model(seq)

assert out.shape == (bs, seq_len, vocab_len)
print("Passed dry run!")

Passed dry run!


In [None]:
# set up wandb
# now = datetime.now()
# run_name = "dlm-" + now.strftime("%Y_%m_%d_%H_%M")  # %M = minutes (%m would repeat the month)

# initialize wandb session
# wandb.login()
# wandb.init(project="diffusion-language-model", name=run_name, config=train_config)

# custom optimizer
optimizer = optim.AdamW(model.parameters(), lr=train_config['lr'], weight_decay=train_config['weight_decay'])

# define warmup and cooldown epochs (always at least one warmup epoch)
warmup_epochs = max(1, train_config['max_epochs'] // 10)
cooldown_epochs = train_config['max_epochs'] - warmup_epochs

# epoch length = number of batches the loader yields per epoch (it drops the last partial batch)
epoch_len = len(train_loader)
warmup_iters = warmup_epochs * epoch_len
cooldown_iters = max(1, cooldown_epochs * epoch_len)  # T_max must be positive

# construct linear warmup and cosine annealing scheduler (stepped once per batch)
linear = LinearLR(optimizer, start_factor=0.25, end_factor=1.0, total_iters=warmup_iters)
cosine = CosineAnnealingLR(optimizer, T_max=cooldown_iters, eta_min=1e-6)
scheduler = SequentialLR(optimizer, schedulers=[linear, cosine], milestones=[warmup_iters])

# set up cross entropy loss for transformer output; padded label positions are ignored
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# main training loop
# iteration = 0
total_iters = train_config['max_epochs'] * epoch_len

# the context manager closes the progress bar even if training raises
with tqdm(total=total_iters, desc="Training Iterations", unit="batch") as pbar:
    for epoch in range(train_config['max_epochs']):
        # set model to train
        model.train()

        for batch_idx, batch in enumerate(train_loader):
            # log lr for each iteration
            # wandb.log({'learning-rate': scheduler.get_last_lr()[0]}, step=iteration)

            batch_inputs = batch[0].to(device)
            batch_labels = batch[1].to(device)

            optimizer.zero_grad()

            # run through model: logits are (batch, sequence, vocab)
            logits = model(batch_inputs)

            # CrossEntropyLoss expects (N, C) logits with (N,) targets, so flatten
            # the batch and sequence dimensions together
            loss = criterion(
                logits.reshape(-1, logits.shape[-1]),
                batch_labels.reshape(-1),
            )

            loss.backward()
            optimizer.step()

            # step through scheduler
            scheduler.step()

            # log loss/train and accuracy/train per batch
            # wandb.log({"loss": loss.item()}, step=iteration)

            pbar.set_postfix(epoch=epoch, loss=f"{loss.item():.4f}")
            pbar.update(1)
            # iteration += 1

# wandb.finish()