Write train loop

In [5]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [24]:
from enum import Enum
from typing import Any, Callable

from datasets import load_from_disk
from jax import numpy as jnp
from jax import random
from flax import core
from flax import linen as nn
from flax import struct
import optax
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
from transformers import PreTrainedTokenizerFast

In [60]:
class JaxRandomSampler(Sampler):
    
    def __init__(self, data_source, key=None):
        self.data_source = data_source
        self.key = key
        
    def __iter__(self):
        yield from random.permutation(self.key, len(self)).tolist()
        
    def __len__(self):
        return len(self.data_source)
    
    
def jax_collate(batch):
    keys = batch[0].keys()
    ret = {}
    for key in keys:
        vals = [item[key] for item in batch]
        try:
            ret[key] = jnp.stack(vals)
        except TypeError:
            ret[key] = vals
    return ret
    
    
class JaxDataLoader(DataLoader):
    
    def __init__(self, dataset, **kwargs):
        sampler = JaxRandomSampler(dataset)
        if 'collate_fn' not in kwargs:
            kwargs['collate_fn'] = jax_collate
        super().__init__(dataset, sampler=sampler, shuffle=False, **kwargs)
        
    def set_key(self, key):
        self.sampler.key = key
         
    def make_iter(self, key):
        self.sampler.key = key
        return iter(self)

        

def dataloader_iter_factory(dataset, **dataloader_kwargs):
    dataloader = JaxDataLoader(dataset, **dataloader_kwargs)
    
    def dataloader_iter(key):
        return dataloader.make_iter(key)
    
    return dataloader_iter

In [8]:
class State(struct.PyTreeNode):
    params: core.FrozenDict[str, Any] = None
    opt_state: optax.OptState = None
    iteration: int = struct.field(pytree_node=False, default=0)
    epoch: int = struct.field(pytree_node=False, default=0)
    seed: int = struct.field(pytree_node=False, default=0)
    rng: Any = struct.field(pytree_node=False, default=None)
    dataloader_rng: Any = struct.field(pytree_node=False, default=None)
    max_epochs: int = struct.field(pytree_node=False, default=1)
    batch: Any = struct.field(pytree_node=False, default=None)
    output: Any = struct.field(pytree_node=False, default=None)
    metrics: Any = struct.field(pytree_node=False, default=None)
    times: Any = struct.field(pytree_node=False, default=None)

    def apply_gradients(self, *, grads, tx, **kwargs):
        updates, new_opt_state = tx.update(
            grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)
        return self.replace(
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )
    
    def next_data_iter(self, data_iter_fn):
        key, subkey = random.split(self.dataloader_rng)
        self.replace(dataloader_rng=key)
        return data_iter_fn(subkey)
    
    @classmethod
    def create(cls, seed):
        key = random.PRNGKey(self.state.seed)
        key, subkey = random.split(key)
        return cls(seed=seed, rng=key, dataloader_rng=subkey)

In [10]:
from collections import defaultdict


class Events(Enum):
    STARTED = 'started'
    EPOCH_STARTED = 'epoch_started'
    GET_BATCH_STARTED = 'get_batch_started'
    GET_BATCH_COMPLETED = 'get_batch_completed'
    ITERATION_STARTED = 'iteration_started'
    ITERATION_COMPLETED = 'iteration_completed'
    DATALOADER_STOP_ITERATION = 'dataloader_stop_iteration'
    EXCEPTION_RAISED = 'exception_raised'
    TERMINATE = 'terminate'
    EPOCH_COMPLETED = 'epoch_completed'
    COMPLETED = 'completed'
    
    
class EngineTerminateException(Exception):
    pass
    


class Engine:
    
    def __init__(self, process_function, seed=0):
        self.event_handlers = defaultdict(list)
        self.state = State.create(seed)
        self.process_function = process_function
        self.setup_default_handlers()
        
    def setup_default_handlers(self):
        pass
        
    def set_state_attr(self, **attr_setting):
        self.state = self.state.replace(**attr_setting)
        
    def increment_state_attr(self, attr):
        self.set_state_attr(attr, getattr(self.state, attr) + 1)
    
    def run(self, data_iter_fn, max_epochs):
        # TODO reset state counters at the beginning
        try:
            self.set_state_attr(max_epochs=max_epochs)
            self.fire_event(Events.STARTED)
            while self.state.epochs < max_epochs:
                self.increment_state_attr('epoch')
                self.fire_event(Events.EPOCH_STARTED)
                data_iter = self.state.next_data_iter(data_iter_fn)
                while True:
                    self.set_state_attr(batch=None, output=None)
                    try:
                        self.fire_event(Events.GET_BATCH_STARTED)
                        self.set_state_attr(batch=next(data_iter))
                        self.fire_event(Events.GET_BATCH_COMPLETED)
                    except StopIteration:
                        self.fire_event(Events.DATALOADER_STOP_ITERATION)
                        continue
                    self.increment_state_attr('iteration')
                    self.fire_event(Events.ITERATION_STARTED)
                    self.set_state_attr(output=self.process_function(self, self.state.batch))
                    self.fire_event(Events.ITERATION_COMPLETED)
                self.fire_event(Events.EPOCH_COMPLETED)
        except EngineTerminateException:
            self.fire_event(Events.TERMINATE)
        self.fire_event(Events.COMPLETED)
        return self.state
        
    def fire_event(self, event, *event_args, **event_kwargs):
        handlers = event_handlers[event]
        for f in handlers:
            self.state = f(self.state, *event_args, **event_kwargs)
    
    def add_event_handler(self, event, f):
        self.event_handlers[event].append(f)
    
    def on(self, event):

        def decorator(f):
            self.add_event_handler(event, f)
            return f

        return decorator

In [12]:
class ZeroLayerTransformer(nn.Module):
    vocab_size: int
    embed_dim: int
    
    @nn.compact
    def __call__(self, input_ids):
        """
        input_ids will be a batch of input ids shape (n_examples, max_seq_len)
        """
        embedded = nn.Embed(self.vocab_size, self.embed_dim, name='embedding_matrix')(input_ids)
        unembedded = nn.Dense(vocab_size, use_bias=False, name='unembedding_matrix')(embedded)
        logits = nn.softmax(unembedded)
        return logits


In [11]:
dataset = load_from_disk('tokenized_wiki_dataset')
dataloader = JaxDataLoader(dataset, batch_size=32)
tokenizer = PreTrainedTokenizerFast(tokenizer_file='tokenizer-wiki.json')

In [61]:
dataloader_iter_fn = dataloader_iter_factory(dataset, batch_size=32)

NameError: name 'key' is not defined

In [56]:
it = dataloader_iter_fn(key1)

In [59]:
next(it)['input_ids']

[[3031,
  15351,
  485,
  447,
  441,
  4131,
  11578,
  780,
  438,
  1780,
  15531,
  978,
  626,
  1248,
  447,
  441],
 [497,
  441,
  1472,
  979,
  3911,
  2613,
  515,
  438,
  2971,
  465,
  441,
  850,
  456,
  3860,
  470,
  2053],
 [4440,
  5471,
  760,
  1048,
  12601,
  1161,
  259,
  13249,
  445,
  927,
  447,
  1856,
  445,
  1348,
  2114,
  4377],
 [3596,
  447,
  441,
  2336,
  5064,
  285,
  655,
  7897,
  1499,
  11884,
  507,
  438,
  12342,
  5182,
  3247,
  465],
 [16182,
  683,
  1268,
  533,
  537,
  604,
  440,
  528,
  442,
  2324,
  447,
  1209,
  445,
  441,
  16608,
  257],
 [2385,
  508,
  1780,
  16816,
  508,
  495,
  1054,
  4245,
  538,
  10916,
  470,
  12019,
  4719,
  17806,
  277,
  1579],
 [1248,
  528,
  442,
  9183,
  470,
  14882,
  441,
  4368,
  465,
  441,
  13156,
  576,
  445,
  5352,
  438,
  917],
 [3241,
  500,
  951,
  11865,
  3432,
  3908,
  528,
  442,
  16870,
  488,
  749,
  1153,
  1041,
  445,
  1437,
  756],
 [470,
  1389,
  9

In [31]:
VOCAB_SIZE =  tokenizer.vocab_size
EMBED_DIM = 12*64
lr = 1e-2

model = ZeroLayerTransformer(vocab_size=VOCAB_SIZE, embed_dim=EMBED_DIM)

In [33]:
sample_data

<torch.utils.data.dataloader._SingleProcessDataLoaderIter at 0x7f8ca7da54c0>

In [40]:
it = dataloader.make_iter(key1)

In [42]:
next(it)

<torch.utils.data.dataloader._SingleProcessDataLoaderIter at 0x7f84f03eeca0>

In [45]:
it = dataloader.make_iter(key1)
for a in it:
    pass

In [47]:
dir(a)

['_IterableDataset_len_called',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__next__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_auto_collation',
 '_base_seed',
 '_collate_fn',
 '_dataset',
 '_dataset_fetcher',
 '_dataset_kind',
 '_drop_last',
 '_index_sampler',
 '_next_data',
 '_next_index',
 '_num_workers',
 '_num_yielded',
 '_persistent_workers',
 '_pin_memory',
 '_pin_memory_device',
 '_prefetch_factor',
 '_profile_name',
 '_reset',
 '_sampler_iter',
 '_shared_seed',
 '_timeout',
 'next']

In [39]:
key1, key2 = random.split(random.PRNGKey(0), 2)
sample_data = next(dataloader.make_iter(key1))

params = model.init(key2, sample_data)

AttributeError: '_SingleProcessDataLoaderIter' object has no attribute 'dtype'

In [None]:
VOCAB_SIZE =  tokenizer.vocab_size
EMBED_DIM = 12*64
lr = 1e-2


model = ZeroLayerTransformer(vocab_size=VOCAB_SIZE, embed_dim=EMBED_DIM)
tx = optax.sgd(learning_rate=lr)


def cross_entropy_loss(params, batch):
    # Define the cross entropy loss for a single pair (x,y)
    def cross_entropy_error(x, y):
        # We don't have the label for the last token
        probs = model.apply(params, x)
        
        
        
        
        return jnp.inner(y-pred, y-pred) / 2.0
    # Vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)


def train_step(engine, batch):
    
    
    

In [None]:
# possible usage

vocab_size = 20
embed_dim = 10

model = ZeroLayerTransformer(vocab_size, embed_dim)


def update(state, batch):
    logits = model.apply(state.params, batch)
    

trainer = Engine(update)


trainer.on(Events.STARTED)
def setup_optimizer(state):
    pass


trainer.on(Events.STARTED)
def init_params(state):
    pass


In [None]:
from datasets import load_from_disk
from torch.utils.data import dataloader


dataset = load_from_disk('tokenized_wiki_dataset')
dl = dataloader(dataset)
trainer.run(dl, max_epochs=10)

In [87]:
from datasets import load_from_disk
from torch.utils.data import DataLoader


dataset = load_from_disk('tokenized_wiki_dataset')
dl = JaxDataLoader(dataset.select(range(5)))

In [63]:
(dl.dataset)

Dataset({
    features: ['input_ids'],
    num_rows: 5
})

In [69]:
def test(a, b=2, *, c=4):
    print(a, b, c)

In [70]:
test(1, 2, 3)

TypeError: test() takes from 1 to 2 positional arguments but 3 were given

In [71]:
test(1, 2, c=3)

1 2 3


In [99]:
dl = JaxDataLoader(dataset.select(range(5)))

In [100]:
next(iter(dl.sampler))

1

In [101]:
for ex in dl:
    pass

In [102]:
ex

{'input_ids': [tensor([11697]),
  tensor([1211]),
  tensor([507]),
  tensor([438]),
  tensor([2010]),
  tensor([5383]),
  tensor([470]),
  tensor([2799]),
  tensor([561]),
  tensor([507]),
  tensor([442]),
  tensor([1424]),
  tensor([655]),
  tensor([465]),
  tensor([4059]),
  tensor([470])]}

In [105]:
for ex in dl:
    pass

In [106]:
ex

{'input_ids': [tensor([1211]),
  tensor([5968]),
  tensor([502]),
  tensor([441]),
  tensor([13317]),
  tensor([465]),
  tensor([441]),
  tensor([992]),
  tensor([445]),
  tensor([636]),
  tensor([544]),
  tensor([5399]),
  tensor([476]),
  tensor([522]),
  tensor([19074]),
  tensor([619])]}

In [43]:
dir(dl)

['_DataLoader__initialized',
 '_DataLoader__multiprocessing_context',
 '_IterableDataset_len_called',
 '__annotations__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_auto_collation',
 '_dataset_kind',
 '_get_iterator',
 '_get_shared_seed',
 '_index_sampler',
 '_is_protocol',
 '_iterator',
 'batch_sampler',
 'batch_size',
 'check_worker_number_rationality',
 'collate_fn',
 'dataset',
 'drop_last',
 'generator',
 'multiprocessing_context',
 'num_workers',
 'persistent_workers',
 'pin_memory',
 'pin_memory_device',
 'prefetch_factor',
 'sampler',
 'timeout',
 'worker_i

In [None]:
next(iter(dl))

In [31]:
next(iter(dl))

{'input_ids': [tensor([11697]),
  tensor([1211]),
  tensor([507]),
  tensor([438]),
  tensor([2010]),
  tensor([5383]),
  tensor([470]),
  tensor([2799]),
  tensor([561]),
  tensor([507]),
  tensor([442]),
  tensor([1424]),
  tensor([655]),
  tensor([465]),
  tensor([4059]),
  tensor([470])]}

In [33]:
dataset[:5]

{'input_ids': [[11697,
   1211,
   507,
   438,
   2010,
   5383,
   470,
   2799,
   561,
   507,
   442,
   1424,
   655,
   465,
   4059,
   470],
  [5544,
   283,
   713,
   2025,
   5384,
   619,
   445,
   733,
   443,
   304,
   620,
   3867,
   465,
   16939,
   447,
   11697],
  [1211,
   5968,
   502,
   441,
   13317,
   465,
   441,
   992,
   445,
   636,
   544,
   5399,
   476,
   522,
   19074,
   619],
  [445,
   3154,
   467,
   17578,
   445,
   470,
   7193,
   1354,
   447,
   509,
   438,
   8146,
   1500,
   496,
   3287,
   2799],
  [445,
   2967,
   497,
   441,
   17351,
   3023,
   1500,
   465,
   441,
   2010,
   10693,
   445,
   544,
   507,
   2797,
   2123]]}

In [2]:
Event.STARTED

<Event.STARTED: 'started'>