In [24]:

#!/usr/bin/env python3
from __future__ import annotations
import dataclasses
import inspect
from dataclasses import dataclass, is_dataclass
from typing import *

import einops
import torch
import torch.nn
import torch.nn.functional
import torch.optim 
import wandb
# import torch # use pytorch for now.

In [9]:
def ceil_div(num : int, denom : int) -> int:
    """Return ``num / denom`` rounded up, using only integer arithmetic."""
    # Floor division of the negated numerator rounds towards -inf; negating the
    # result therefore rounds the original quotient towards +inf.
    return -(-num // denom)
    
def get_k_from_iter(iter : Iterator, k : int):
    """Pull at most ``k`` items from ``iter`` and return them as a list.

    The iterator is advanced in place, so a later call continues where this
    one stopped. Fewer than ``k`` items are returned when it runs dry.
    """
    exhausted = object()  # unique marker that cannot collide with real items
    taken = []
    for _ in range(k):
        item = next(iter, exhausted)
        if item is exhausted:
            break
        taken.append(item)
    return taken

def test():
    """Smoke test: consecutive pulls from one shared iterator do not overlap."""
    shared = iter([1, 2, 3, 4, 5])
    for _ in range(2):
        # iter() on an iterator hands back the same object, so state is shared.
        print(get_k_from_iter(iter(shared), 2))
test()

[1, 2]
[3, 4]


In [25]:
@dataclass
class Config:
    """Settings read by TrainingLoop (W&B project name and loop sizes)."""
    project : str # project this config belongs to; passed to wandb.init as the project name.
    nepochs : int # number of epochs TrainingLoop.run iterates over.
    nbatches : int # number of batches the dataset is divided into (run_batch: ceil_div(len(dataset), nbatches)); one optimizer step per batch.
    ministep_size : int # number of samples fetched and collated for each ministep of a batch.
    

In [26]:

# Annotation-only alias meaning "an instance of some dataclass". It wraps Any,
# so it cannot be used to construct values (a NewType takes one positional
# argument and simply returns it).
DataClass = NewType('DataClass', Any)

class DataCollator:
    """Interface: merge a list of samples into one batched sample."""
    def __call__(self, data : List[DataClass]) -> DataClass:
        raise NotImplementedError

class DefaultDataCollator(DataCollator):
    """Collate dataclass samples by stacking every field along a new leading axis."""
    def __call__(self, data : List[DataClass]) -> DataClass:
        """Return one instance of the samples' own class holding stacked fields.

        Args:
            data: non-empty list of instances of the same dataclass whose
                fields are tensors of matching shape.

        Returns:
            A new instance of ``type(data[0])`` where each field is
            ``torch.stack`` of that field over all samples.

        Raises:
            AssertionError: if ``data`` is empty.
            TypeError: if the samples are not dataclass instances.
        """
        assert len(data) > 0
        # Rebuild the concrete sample class: the DataClass alias is not constructible.
        sample_type = type(data[0])
        collated = {
            f.name: torch.stack([getattr(d, f.name) for d in data])
            for f in dataclasses.fields(data[0])
        }
        return sample_type(**collated)


In [27]:

class Dataset:
    """Interface for a source of samples that can hand out fresh iterators.

    Subclasses used with TrainingLoop must also implement ``__len__``, because
    the loop sizes its batches from ``len(dataset)``.
    """
    def mk_iterator(self) -> Iterator[DataClass]:
        raise NotImplementedError

class IterableDataset(Dataset):
    """Dataset backed by an in-memory iterable of samples."""
    def __init__(self, data : Iterable[DataClass]):
        # NOTE: pass a re-iterable container (list, tuple, ...). A one-shot
        # generator would be exhausted after the first mk_iterator() pass.
        self.data = data

    def mk_iterator(self) -> Iterator[DataClass]:
        """Return a new iterator over the wrapped data."""
        return iter(self.data)

    def __len__(self) -> int:
        """Number of samples; raises TypeError if the wrapped iterable has no length."""
        return len(self.data)


In [28]:

# partially ordered
@dataclass
class ValidationMetric:
    """Validation scores for one epoch; every field is loss-like (larger is worse).

    Two metrics are compared field by field, which yields only a partial
    order: one metric can be larger in one field and smaller in another.
    """
    epoch_loss : float

    def asdict(self) -> Dict[str, Any]:
        """Return the fields as a plain ``{name: value}`` dict (e.g. for logging)."""
        return dataclasses.asdict(self)

    def is_strictly_worse_than(self, other : ValidationMetric) -> Optional[bool]:
        """Compare ``self`` against ``other`` under the field-wise partial order.

        Returns:
            True  -- self is >= other in every field and > in at least one.
            False -- self is <= other in every field (this includes equality).
            None  -- incomparable: worse in some fields, better in others.

        Raises:
            AssertionError: if the two metrics do not have the same fields.
        """
        mine = dataclasses.asdict(self)
        theirs = dataclasses.asdict(other)
        assert mine.keys() == theirs.keys()
        never_better = all(mine[k] >= theirs[k] for k in mine)
        never_worse = all(mine[k] <= theirs[k] for k in mine)
        if never_worse:
            # Covers both "equal everywhere" and "better somewhere, worse nowhere".
            return False
        if never_better:
            return True
        return None



In [29]:

@dataclass
class TrainingState:
    """Mutable bookkeeping shared between the training loop, evaluators and callbacks.

    Only ``model``, ``batch_loss`` and ``epoch_loss`` are required; the
    remaining fields start from neutral defaults so the state can be created
    before training begins.
    """
    model : torch.nn.Module  # model being trained
    batch_loss : torch.Tensor  # loss accumulated over the ministeps of the current batch
    epoch_loss : float  # loss value reported for the current epoch
    bix : int = 0  # index of the current batch within the epoch
    eix : int = 0  # index of the current epoch
    # Iterator the loop draws samples from; defaults to an empty iterator.
    data_iter : Iterator[DataClass] = dataclasses.field(default_factory=lambda: iter(()))
    # History of validation metrics, one per finished epoch. A factory is used
    # so that instances never share the same list.
    metrics : List[ValidationMetric] = dataclasses.field(default_factory=list)



In [30]:

class Evaluator:
    """Callable interface that turns the current training state into a metric."""

    def __call__(self, dataset: Dataset, state: TrainingState) -> ValidationMetric:
        """Score ``state`` (optionally using ``dataset``); subclasses must override."""
        raise NotImplementedError()
        
class EpochLossEvaluator(Evaluator):
    """Evaluator that reports the epoch loss already stored on the training state."""

    def __call__(self, dataset: Dataset, state: TrainingState) -> ValidationMetric:
        """Wrap ``state.epoch_loss`` in a ValidationMetric; ``dataset`` is not read."""
        recorded_loss = state.epoch_loss
        return ValidationMetric(epoch_loss=recorded_loss)

In [31]:


class TrainingCallback:
    """Base class for training hooks; every hook defaults to a no-op command.

    All hooks share the same ``(config, state)`` parameters. For
    ``after_batch`` and ``after_epoch`` they are optional, so calls that pass
    no arguments keep working.

    NOTE(review): ``TrainingCommand`` is not defined in the cells shown here;
    it must exist before any hook is invoked.
    """
    def before_train(self, config : Config, state : TrainingState) -> TrainingCommand:
        return TrainingCommand()
    def before_batch(self, config : Config, state : TrainingState) -> TrainingCommand:
        return TrainingCommand()
    def before_mini_step(self, config : Config, state : TrainingState) -> TrainingCommand:
        return TrainingCommand()
    def after_mini_step(self, config : Config, state : TrainingState) -> TrainingCommand:
        return TrainingCommand()
    def after_batch(self, config : Optional[Config] = None, state : Optional[TrainingState] = None) -> TrainingCommand:
        return TrainingCommand()
    def before_epoch(self, config : Config, state : TrainingState) -> TrainingCommand:
        return TrainingCommand()
    def after_epoch(self, config : Optional[Config] = None, state : Optional[TrainingState] = None) -> TrainingCommand:
        return TrainingCommand()



In [32]:
class TrainingLoop:
    """Epoch / batch / ministep training driver with W&B logging and checkpointing.

    Each epoch is split into ``config.nbatches`` batches. A batch is processed
    as several ministeps whose gradients are accumulated, followed by a single
    optimizer step. After every epoch the evaluator produces a
    ValidationMetric; a checkpoint is written when that metric is not strictly
    worse than any earlier one.

    The dataset must support ``len()``, and the model must return a scalar
    loss tensor when called on a collated ministep.
    """
    _config : Config
    _model : torch.nn.Module
    _dataset : Dataset
    _collator : DataCollator
    _evaluator : Evaluator
    _optimizer : torch.optim.Optimizer
    _state : TrainingState
    _wandb_run : wandb.Run
    # Users should freely modify `callbacks` to add new callbacks.
    callbacks : List[TrainingCallback]

    def __init__(self, config : Config, model : torch.nn.Module, dataset : Dataset,
                 collator: DataCollator, evaluator : Evaluator,
                 optimizer: torch.optim.Optimizer):
        """Store the collaborators, create the initial state and start a W&B run."""
        self._config = config
        # Every field is passed explicitly so construction does not depend on
        # TrainingState providing defaults.
        self._state = TrainingState(model=model,
                                    batch_loss=torch.tensor(0.0, dtype=float),
                                    epoch_loss=0.0,
                                    bix=0,
                                    eix=0,
                                    data_iter=dataset.mk_iterator(),
                                    metrics=[])
        self._model = model
        self._dataset = dataset
        self._collator = collator
        self._evaluator = evaluator
        self._optimizer = optimizer
        self._wandb_run = wandb.init(config=dataclasses.asdict(self._config),
                                     project=self._config.project,
                                     save_code=True,
                                     magic=True,
                                     config_exclude_keys=["project"])
        self.callbacks = []

    def _fire(self, hook_name : str) -> None:
        """Invoke the hook named ``hook_name`` on every callback, in order.

        Hooks receive ``(config, state)`` when their signature accepts two
        positional arguments; hooks declared without parameters are called
        with none. Hook return values are ignored.
        """
        for cb in self.callbacks:
            hook = getattr(cb, hook_name)
            try:
                inspect.signature(hook).bind(self._config, self._state)
            except TypeError:
                # The hook does not accept (config, state).
                hook()
            else:
                hook(self._config, self._state)

    def run(self):
        """Train for ``config.nepochs`` epochs, then close the W&B run."""
        self._fire("before_train")
        for eix in range(self._config.nepochs):
            self._state.eix = eix
            # Start each epoch from the beginning of the dataset; the previous
            # epoch has consumed the old iterator.
            self._state.data_iter = self._dataset.mk_iterator()
            self._fire("before_epoch")
            self._state.epoch_loss = 0.0
            self.run_epoch()
            self._fire("after_epoch")
        wandb.finish()

    def run_epoch(self):
        """Run all batches of one epoch, evaluate, and checkpoint if warranted."""
        for bix in range(self._config.nbatches):
            self._state.bix = bix
            self.run_batch()  # run_batch fires before_batch / after_batch itself

        cur_validation = self._evaluator(self._dataset, self._state)
        wandb.log(dataclasses.asdict(cur_validation))
        # The new metric stays on the Pareto frontier unless some earlier
        # metric strictly dominates it. An "incomparable" (None) result does
        # not count as dominated.
        is_pareto_frontier = not any(
            cur_validation.is_strictly_worse_than(past_validation)
            for past_validation in self._state.metrics
        )

        self._state.metrics.append(cur_validation)
        if is_pareto_frontier:
            torch.save({
                "eix": self._state.eix,
                "model": self._model.state_dict(),
                "optim": self._optimizer.state_dict(),
            }, f"model-{self._state.eix}.pt")

    def run_batch(self):
        """Accumulate gradients over the batch's ministeps, then step the optimizer."""
        size_per_batch = ceil_div(len(self._dataset), self._config.nbatches)
        nministeps = ceil_div(size_per_batch, self._config.ministep_size)
        self._state.batch_loss = torch.tensor(0.0, dtype=float)

        self._optimizer.zero_grad()
        self._fire("before_batch")

        for msix in range(nministeps):
            self.run_ministep(msix)

        batch_loss = float(self._state.batch_loss)
        self._state.epoch_loss += batch_loss  # what EpochLossEvaluator reports
        wandb.log({"batch_loss": batch_loss})

        self._optimizer.step()

        self._fire("after_batch")

    def run_ministep(self, msix : int):
        """Process ministep ``msix``: fetch samples, compute the loss, backpropagate.

        Does nothing when the data iterator is already exhausted.
        """
        samples = get_k_from_iter(self._state.data_iter, self._config.ministep_size)
        if not samples:
            return
        self._fire("before_mini_step")
        # collate data for ministep
        data = self._collator(samples)
        ms_loss = self._model(data)
        # Gradients add up across ministeps until run_batch steps the optimizer.
        ms_loss.backward()
        # Track the value only; keeping the graph alive would hold on to memory.
        self._state.batch_loss += ms_loss.detach()
        self._fire("after_mini_step")

In [35]:
@dataclass
class LinearConfig:
    """Dimensions of an affine map from ``din`` features to ``dout`` features."""
    din : int
    dout : int
    # Ax + b
    def Asize(self):
        """Shape of the weight matrix A: ``[din, dout]``."""
        return [self.din, self.dout]
    def bsize(self):
        """Shape of the bias vector b: ``[dout]``."""
        return [self.dout]


class Linear(torch.nn.Module):
    """Affine layer ``y = x @ W + b`` with trainable ``W`` and ``b``."""
    config : LinearConfig
    W : torch.Tensor  # weight, shape [din, dout]
    b : torch.Tensor  # bias, shape [dout]
    def __init__(self, config: LinearConfig):
        # Module.__init__ must run first so the parameters below get registered.
        super().__init__()
        self.config = config
        # Standard-normal initialisation; nn.Parameter makes the tensors
        # visible to optimizers and state_dict().
        self.W = torch.nn.Parameter(torch.randn(config.din, config.dout))
        self.b = torch.nn.Parameter(torch.randn(config.dout))

    def forward(self, x:torch.tensor):
        """Apply the layer to ``x`` of shape ``[..., din]``; returns ``[..., dout]``.

        The feature axis is the last one; any leading (batch) axes are preserved.
        """
        return torch.matmul(x, self.W) + self.b
    


In [40]:
@dataclass
class MaskKind:
    """Tag selecting an attention mask pattern: ``CAUSAL`` or ``NONE``."""
    # Integer codes for the supported kinds, plus the list of all valid codes.
    CAUSAL, NONE = range(2)
    MASKS = [CAUSAL, NONE]
    kind : int
    def __init__(self, kind : int ):
        """Store ``kind`` after checking that it is one of ``MaskKind.MASKS``."""
        valid_kinds = MaskKind.MASKS
        assert kind in valid_kinds
        self.kind = kind



@dataclass
class SingleHeadAttnConfig:
    """Dimensions of one attention head and the linear configs derived from them."""
    din : int
    dembed : int
    dout : int

    def _from_input(self, width : int) -> LinearConfig:
        """Config of a linear map from ``din`` input features to ``width`` features."""
        return LinearConfig(self.din, width)

    @property
    def Qconfig(self) -> LinearConfig:
        """Query projection: din -> dembed."""
        return self._from_input(self.dembed)

    @property
    def Kconfig(self) -> LinearConfig:
        """Key projection: din -> dembed."""
        return self._from_input(self.dembed)

    @property
    def Vconfig(self) -> LinearConfig:
        """Value projection: din -> dout."""
        return self._from_input(self.dout)

class SingleHeadAttn(torch.nn.Module):
    """One head of scaled dot-product attention of ``xs`` over a context ``ctxs``."""
    config : SingleHeadAttnConfig
    Q : Linear
    K : Linear
    V : Linear
    def __init__(self, config : SingleHeadAttnConfig):
        # Module.__init__ must run before sub-modules are assigned.
        super().__init__()
        self.config = config
        self.Q = Linear(config.Qconfig)
        self.K = Linear(config.Kconfig)
        self.V = Linear(config.Vconfig)

    @staticmethod
    def _blocked_positions(mask, scores : torch.Tensor) -> Optional[torch.Tensor]:
        """Translate ``mask`` into a boolean tensor of score entries to suppress.

        Returns None when nothing has to be masked.
        """
        if mask is None:
            return None
        if isinstance(mask, torch.Tensor):
            # Explicit mask: truthy entries are suppressed; must broadcast
            # against [NBATCH;NCTX;NX].
            return mask.to(dtype=torch.bool, device=scores.device)
        if mask.kind == MaskKind.CAUSAL:
            nctx, nx = scores.shape[-2:]
            # Query position x may only look at context positions c <= x, so
            # block the strictly-lower triangle (c > x) of the [NCTX;NX] grid.
            return torch.ones(nctx, nx, dtype=torch.bool, device=scores.device).tril(-1)
        return None

    # note that the context is shared for /every/ xs.
    # xs : [NBATCH;NX;DIN]
    # ctxs:[NBATCH;NCTX;DIN]
    # out: [NBATCH;NX;DOUT]
    def forward(self, xs : torch.tensor, ctxs : torch.tensor, mask : MaskKind=None):
        """Attend from every position of ``xs`` over the positions of ``ctxs``.

        Args:
            xs: queries' inputs, [NBATCH;NX;DIN].
            ctxs: context inputs, [NBATCH;NCTX;DIN].
            mask: None (no masking), a MaskKind (CAUSAL blocks context
                positions after the query position), or a tensor broadcastable
                to [NBATCH;NCTX;NX] whose truthy entries are blocked.

        Returns:
            Tensor of shape [NBATCH;NX;DOUT].
        """
        queries = torch.matmul(xs, self.Q.W) + self.Q.b    # [NBATCH;NX;DEMBED]
        keys = torch.matmul(ctxs, self.K.W) + self.K.b     # [NBATCH;NCTX;DEMBED]
        vs = torch.matmul(ctxs, self.V.W) + self.V.b       # [NBATCH;NCTX;DOUT]
        # "attn matrix" (unnormalized), scaled by sqrt(DEMBED) so the softmax
        # does not saturate as the embedding width grows.
        scores = torch.einsum("bce,bxe->bcx", keys, queries) / (queries.shape[-1] ** 0.5)
        blocked = self._blocked_positions(mask, scores)
        if blocked is not None:
            scores = scores.masked_fill(blocked, float("-inf"))
        weights = torch.softmax(scores, dim=1)  # normalise along NCTX
        # for each word, compute the output representation.
        return torch.einsum("bcx,bco->bxo", weights, vs)   # [NBATCH;NX;DOUT]



[1, 2] 1 2


In [41]:
# code for mask: https://pytorch.org/text/stable/_modules/torchtext/nn/modules/multiheadattention.html
@dataclass
class MultiHeadAttnConfig:
    """Dimensions of a multi-head attention block and the configs derived from them."""
    din : int
    dembedSingleHead : int
    doutSingleHead : int
    dout : int
    nheads : int

    @property
    def singleHeadAttnConfig(self) -> SingleHeadAttnConfig:
        """Config shared by every head: din -> dembedSingleHead / doutSingleHead."""
        head_dims = dict(din=self.din, dembed=self.dembedSingleHead, dout=self.doutSingleHead)
        return SingleHeadAttnConfig(**head_dims)

    @property
    def finalLinearConfig(self) -> LinearConfig:
        """Config of the output projection applied to the concatenated heads."""
        concat_width = self.nheads * self.doutSingleHead
        return LinearConfig(din=concat_width, dout=self.dout)

class MultiHeadAttn(torch.nn.Module):
    """Multi-head attention: independent heads, concatenated, then projected."""
    heads : List[SingleHeadAttn]
    Wout : Linear

    def __init__(self, config : MultiHeadAttnConfig):
        # Module.__init__ must run before sub-modules are assigned.
        super().__init__()
        self.config = config
        # ModuleList (rather than a plain list) registers every head's
        # parameters with this module.
        self.heads = torch.nn.ModuleList(
            SingleHeadAttn(config.singleHeadAttnConfig) for _ in range(config.nheads)
        )
        self.Wout = Linear(config.finalLinearConfig)

    # xs : [NBATCH;NX;DIN]
    # zs: [NBATCH;NCTX;DIN]
    # out: [NBATCH;NX;DOUT]
    def forward(self, xs : torch.tensor, zs : torch.tensor, mask : torch.tensor = None):
        """Run every head on ``(xs, zs, mask)`` and project the concatenated result.

        Args:
            xs: inputs that attend, [NBATCH;NX;DIN].
            zs: context that is attended over, [NBATCH;NCTX;DIN].
            mask: forwarded unchanged to every head; None disables masking.

        Returns:
            Tensor of shape [NBATCH;NX;DOUT].
        """
        # Each head yields [NBATCH;NX;DOUT_SINGLE]; joining along the feature
        # axis gives [NBATCH;NX;(NHEADS*DOUT_SINGLE)].
        ys = torch.cat([head(xs, zs, mask) for head in self.heads], dim=-1)
        return torch.matmul(ys, self.Wout.W) + self.Wout.b


In [None]:
@dataclass
class SelfAttentionDecoderConfig:
    """Hyper-parameters of the decoder-only model and the configs derived from them."""
    max_input_length : int # the decoder asserts that sentences are shorter than this
    nvocab: int #number of words in vocabulary to produce probability distribution over. 
    dembedSingleHead : int # query/key width of each attention head
    doutSingleHead : int # value/output width of each attention head
    nheads : int # number of attention heads
    nlayers : int # number of times to perform self attention
    din : int # input embedding dimension

    @property
    def wordEmbeddingShape(self):
        """Shape of the word-embedding table: ``(din, nvocab)``."""
        return (self.din, self.nvocab)

    @property 
    def multiHeadAttnConfig(self):
        """Config for the decoder's multi-head attention block."""
        # NOTE(review): dout is set to the vocabulary size. The decoder adds
        # the attention output back onto its [.., din] input, which presumably
        # needs dout == din -- confirm the intended output width.
        return MultiHeadAttnConfig(din=self.din,
                                nheads=self.nheads,
                                dembedSingleHead=self.dembedSingleHead,
                                doutSingleHead=self.doutSingleHead,
                                dout=self.nvocab)

# decoder only language model (e.g. gpt)
class SelfAttentionDecoder(torch.nn.Module):
    """Draft of a decoder-only language model built from MultiHeadAttn.

    NOTE(review): this class looks unfinished -- the forward pass ends with a
    placeholder ``gelu(...)`` call and has no return statement in the code
    shown here. The review notes below list what has to be settled before it
    can run.
    """
    attns : MultiHeadAttn
    word_embeds : torch.tensor
    def __init__(self, config : SelfAttentionDecoderConfig):
        """Store ``config`` and create the embedding table and attention block(s)."""
        # NOTE(review): torch.nn.Module.__init__ is never called here.
        self.config = config
        # NOTE(review): torch.random is a module, not a callable -- presumably
        # torch.randn(config.wordEmbeddingShape) was intended; confirm.
        self.word_embeds = torch.random(shape=config.wordEmbeddingShape)
        # A one-element list is stored here; see the call site in the layer loop.
        self.attns = [MultiHeadAttn(self.config.multiHeadAttnConfig)]
    # return probabilities
    # xs: one hot encoding of words.
    # xs: [[word]] = [sentence]
    # outs: [NVOCAB;NWORDS;NBATCH] 
    # LorA at each step!
    # for us, NCONTEXT = NWORDS
    # NOTE(review): the method is spelled `foward`; torch.nn.Module dispatches
    # calls to `forward`, so this is presumably a typo -- confirm before renaming.
    def foward(self, sentences: torch.tensor):
        # please don't disappoint me, einops.
        # Longest sentence in the batch; must stay below the configured limit.
        NWORDSMAX = max([len(sent) for sent in sentences])
        assert NWORDSMAX < self.config.max_input_length
        
        # how to efficiently create mask?
        # mask = ...
        mask = None 
        # FIXME: batch dimension is **FIRST DIMENSION**
        # xs: [NBATCH;NWORDSMAX;DIN]
        # NOTE(review): the pattern below ends in "e w b", i.e. batch LAST,
        # which disagrees with the shape comment above -- confirm the layout.
        xs = einops.rearrange([[self.word_embeds[:, wix] for wix in sentence] for sentence in sentences] , "b w e -> e w b")
        
        # NOTE(review): self.layers is not assigned in __init__; the config
        # exposes `nlayers`, which is presumably what was meant.
        for l in range(self.layers):
            # TODO: separate layer norm for each word in sentence, wtf.
            xs_tilde = torch.nn.functional.layer_norm(input=xs,
                                                      normalized_shape=(self.config.din, NWORDSMAX))
            # NOTE(review): self.attns is a list (see __init__) and is called
            # directly here; presumably one attention block per layer was intended.
            xs = xs + self.attns(xs_tilde, xs_tilde, mask) # residual
            
            # NOTE(review): this layer_norm call passes no normalized_shape,
            # unlike the call above.
            xs_tilde = torch.nn.functional.layer_norm(input=xs)            
            # Placeholder: the feed-forward part is not written yet.
            xs = xs + torch.nn.functional.gelu(...)
        