In [95]:
from datasets import load_dataset
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader

batch_size = 32
max_length = 128

dataset = load_dataset('glue', 'sst2')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize_examples(examples):
    return tokenizer(examples['sentence'], truncation=True, padding='max_length', max_length=max_length)

tokenized_dataset = dataset.map(tokenize_examples, batched=True)

# Used to create batches of tokens
def data_collator(features): 
    input_ids = [f['input_ids'] for f in features]
    return torch.tensor(input_ids)

train_dataloader = DataLoader(tokenized_dataset['train'].shuffle(seed=42), batch_size=batch_size, collate_fn=data_collator)
val_dataloader = DataLoader(tokenized_dataset['validation'].shuffle(seed=42), batch_size=batch_size, collate_fn=data_collator)

Map: 100%|██████████| 67349/67349 [00:02<00:00, 25322.64 examples/s]
Map: 100%|██████████| 872/872 [00:00<00:00, 26540.83 examples/s]
Map: 100%|██████████| 1821/1821 [00:00<00:00, 27462.74 examples/s]


In [118]:
import tqdm
import torch
from model.torch_impl.language_model import TorchLanguageModel, LanguageModelConfig

def do_train():
    model = TorchLanguageModel(LanguageModelConfig(**{
        'vocab_size': tokenizer.vocab_size,
        'context_length': max_length,
        'embedding_dim': 8,
        'num_decoder_layers': 3,
        'num_heads': 2,
        'dim_feedforward': 32,
        'dropout': 0.1,
    }))
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(10):
        model.train()
        iters = 0
        for batch in tqdm.tqdm(train_dataloader):
            inputs = batch[:, :-1]
            labels = batch[:, 1:]
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output.view(-1, tokenizer.vocab_size), labels.flatten())
            loss.backward()
            optimizer.step()
            iters += 1
            if iters > 50:
                break
        
        model.eval()
        iters = 0
        with torch.no_grad():
            total_correct = 0
            for batch in val_dataloader:
                inputs = batch[:, :-1]
                labels = batch[:, 1:]
                output = model(inputs)
                _, predicted = torch.max(output, dim=2)
                total_correct += (predicted == labels).sum().item()
                iters += 1
                if iters > 10:
                    break
            accuracy = total_correct / len(val_dataloader.dataset)
            print(f'Epoch {epoch+1}, Validation Accuracy: {accuracy:.4f}')

In [119]:
# TODO: not yet certain that the model is working as intended or that training is well behaved

do_train()

  2%|▏         | 50/2105 [00:27<18:57,  1.81it/s]


Epoch 1, Validation Accuracy: 41.4702


  2%|▏         | 50/2105 [00:27<19:01,  1.80it/s]


Epoch 2, Validation Accuracy: 41.4702


  2%|▏         | 50/2105 [00:28<19:17,  1.78it/s]


Epoch 3, Validation Accuracy: 41.4702


  2%|▏         | 50/2105 [00:27<19:08,  1.79it/s]


Epoch 4, Validation Accuracy: 41.4702


  2%|▏         | 50/2105 [00:27<18:56,  1.81it/s]


Epoch 5, Validation Accuracy: 41.4702


  2%|▏         | 50/2105 [00:28<19:19,  1.77it/s]


Epoch 6, Validation Accuracy: 41.4702


  2%|▏         | 50/2105 [00:28<19:26,  1.76it/s]


Epoch 7, Validation Accuracy: 41.4702


  2%|▏         | 50/2105 [00:28<19:27,  1.76it/s]


Epoch 8, Validation Accuracy: 41.4702


  2%|▏         | 50/2105 [00:28<19:22,  1.77it/s]


Epoch 9, Validation Accuracy: 41.4702


  2%|▏         | 50/2105 [00:28<19:36,  1.75it/s]


Epoch 10, Validation Accuracy: 41.4702
