In [1]:
%cd ..

/mnt/data4/haryoaw_workspace/projects/2021_2/s4_happy/happy_s4


In [2]:
import datasets
from typing import List

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the training split of the GLUE "mrpc" configuration through the
# Hugging Face `datasets` library.
mrpcc_data = datasets.load_dataset("glue", name="mrpc", split="train")

Reusing dataset glue (/mnt/data1/hf_dataset_cache/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


In [4]:
from collections import Counter

In [5]:
from torchtext.vocab import vocab

In [6]:
import torchtext

In [7]:
import re

In [127]:
class WordTokenizer:
    r"""
    Word-level tokenizer.

    Text is split with a regular expression and the resulting tokens are
    mapped to integer ids through a vocabulary that is built by `fit`.
    """

    def __init__(self, special_tokens: List[str] = None, min_freq: int = 1,
                 split_pattern: str = r"\s+", unk_token: str = "[UNK]", pad_token: str = '[PAD]',
                 eos_token: str = "[EOS]", sos_token: str = "[SOS]",
                 lowercase: bool = False):
        r"""
        Word Tokenizer

        Parameters
        ----------
        special_tokens: List[str]
            Extra special tokens, appended after the four built-in ones
            (pad, unk, eos, sos). Lowercased when `lowercase` is True so that
            they can be matched by `tokenize_to_ids`.
        min_freq: int
            Minimum frequency of the token, by default = 1
        split_pattern: str
            Regular expression used to split the text into tokens,
            by default `\s+` (whitespace)
        unk_token: str
            Token used for out-of-vocabulary words
        pad_token: str
            Token used for padding
        eos_token: str
            End-of-sentence token
        sos_token: str
            Start-of-sentence token
        lowercase: bool
            Lowercase the text (and every special token) before tokenizing?
        """
        self.lowercase = lowercase
        self.min_freq = min_freq
        self.split_pattern = split_pattern

        self.pad_token = self._normalize(pad_token)
        self.unk_token = self._normalize(unk_token)
        self.sos_token = self._normalize(sos_token)
        self.eos_token = self._normalize(eos_token)

        # Order matters: it determines the ids the special tokens receive.
        default_specials = [self.pad_token, self.unk_token, self.eos_token, self.sos_token]
        # Input text is lowercased before lookup, so caller-supplied special
        # tokens must be normalized the same way or they could never match.
        extra_specials = [self._normalize(tok) for tok in (special_tokens or [])]
        # dict.fromkeys drops duplicates while keeping first-seen order.
        self.special_tokens = list(dict.fromkeys(default_specials + extra_specials))

        # Fit variables
        self.vocab = None
        self.pad_token_id = None
        self.unk_token_id = None
        self.eos_token_id = None
        self.sos_token_id = None

    def _normalize(self, text: str) -> str:
        """Lowercase `text` when the tokenizer is configured to do so."""
        return text.lower() if self.lowercase else text

    def _split(self, text: str) -> List[str]:
        """Normalize and split `text`, dropping the empty strings that
        `re.split` produces for leading/trailing separators or empty input."""
        return [tok for tok in re.split(self.split_pattern, self._normalize(text)) if tok]

    def fit(self, *data: List[str]) -> None:
        """
        Fit data's vocabulary

        Parameters
        ----------
        data: List[str]
            One or more collections of texts; all of them contribute to a
            single shared vocabulary.
        """
        token_counts = Counter()
        for collection in data:
            for text in collection:
                token_counts.update(self._split(text))
        self.vocab = vocab(
            token_counts,
            min_freq=self.min_freq,
            specials=self.special_tokens,
        )
        self.pad_token_id = self.vocab[self.pad_token]
        self.unk_token_id = self.vocab[self.unk_token]
        self.eos_token_id = self.vocab[self.eos_token]
        self.sos_token_id = self.vocab[self.sos_token]

    def tokenize_to_ids(self, inp: str, with_eos_sos: bool = False) -> List[int]:
        """
        Tokenize a text input to its input indices

        Parameters
        ----------
        inp: str
            Text input
        with_eos_sos: bool
            Wrap the result with the start-of-sentence and end-of-sentence ids?

        Returns
        -------
        List[int]
            Token ids; tokens missing from the vocabulary map to the id of
            the unknown token.

        Raises
        ------
        RuntimeError
            If called before `fit`.
        """
        if self.vocab is None:
            raise RuntimeError("WordTokenizer.fit must be called before tokenize_to_ids")
        unk_id = self.vocab[self.unk_token]
        ids = [self.vocab[tok] if tok in self.vocab else unk_id for tok in self._split(inp)]
        if with_eos_sos:
            ids = [self.sos_token_id] + ids + [self.eos_token_id]
        return ids

In [136]:
# Build a lowercasing tokenizer with two extra special tokens and fit its
# vocabulary on both sentence columns of the training split.
wt = WordTokenizer(lowercase=True, special_tokens=['[sep]', '[cls]'])
wt.fit(mrpcc_data['sentence1'], mrpcc_data['sentence2'])

In [137]:
wt.tokenize_to_ids("[sep]")

[4]

In [138]:
# TODO: pad sequences to a common length (see BatchCollators below)

In [139]:
# TODO modelling :3

In [140]:
def shape_dataset(inp_dict, word_tokenizer, text_cols=('sentence1', 'sentence2'), label_col='label',
                  cls_token='[CLS]', sep_token='[SEP]'):
    """
    Turn one raw example into a label plus a single flat list of token ids.

    The ids are laid out as: ids of `cls_token`, then for every column in
    `text_cols` the ids of that column's text followed by the ids of
    `sep_token`.

    Parameters
    ----------
    inp_dict: dict
        One example; must contain every key in `text_cols` and `label_col`.
    word_tokenizer:
        Any object exposing `tokenize_to_ids(text) -> list of ids`.
    text_cols: sequence of str
        Keys of the text fields to concatenate, in order.
    label_col: str
        Key of the label field.
    cls_token: str
        Text of the token placed at the start, by default '[CLS]'.
    sep_token: str
        Text of the token placed after every text field, by default '[SEP]'.

    Returns
    -------
    dict
        `{'label': <label>, 'input_ids': <list of ids>}`
    """
    # The separator ids do not depend on the column, so tokenize them once.
    sep_ids = word_tokenizer.tokenize_to_ids(sep_token)
    # Copy so that a tokenizer returning a shared list is never mutated.
    input_ids = list(word_tokenizer.tokenize_to_ids(cls_token))
    for col in text_cols:
        input_ids.extend(word_tokenizer.tokenize_to_ids(inp_dict[col]))
        input_ids.extend(sep_ids)
    return dict(label=inp_dict[label_col], input_ids=input_ids)
    

In [141]:
inp_ex = {
    'sentence1': "hi my name is bejo",
    'sentence2': "Meong Meong",
    'label': 0
}

In [142]:
shape_dataset(inp_ex, word_tokenizer=wt)

{'label': 0, 'input_ids': [5, 1, 2634, 1286, 281, 1, 4, 1, 1, 4]}

In [143]:
from functools import partial

In [144]:
# Apply shape_dataset to every example; `partial` binds the fitted tokenizer
# so that `map` only has to supply the example itself.
mrpcc_data = mrpcc_data.map(partial(shape_dataset, word_tokenizer=wt))

100%|█████████████████████████████████████████████████████████████████████████████| 3668/3668 [00:00<00:00, 6145.25ex/s]


In [171]:
mrpcc_data.column_names

['sentence1', 'sentence2', 'label', 'idx', 'input_ids']

In [173]:
# Drop the raw text and index columns so only `label` and `input_ids` remain.
# NOTE(review): the result is bound to the same name, so re-running this cell
# (or the tokenizer-fitting cell above) after it has run will presumably fail
# because the removed columns are gone — confirm before relying on re-runs.
mrpcc_data = mrpcc_data.remove_columns(['sentence1', 'sentence2', 'idx'])

In [174]:
from pprint import pprint

In [194]:
import torch

In [330]:
class BatchCollators:
    """
    Collate callable for a DataLoader.

    Takes a list of examples (mappings with 'input_ids' and 'label'), pads
    every id sequence on the right up to the longest sequence in the batch,
    and returns both fields as LongTensors.
    """

    def __init__(self, pad_token_ids):
        # Recorded for reference only; no other strategy is read in this class.
        self.pad_strategy = 'max_length'
        # Id appended to sequences shorter than the batch maximum.
        self.pad_token_ids = pad_token_ids

    def _pad_helper(self, x, max_length):
        """Return a new list: `x` followed by enough pad ids to reach `max_length`."""
        filler = [self.pad_token_ids] * (max_length - len(x))
        return x + filler

    def _pad_self(self, input_ids):
        """Pad every sequence in `input_ids` to the length of the longest one."""
        longest = max(len(seq) for seq in input_ids)
        return [self._pad_helper(seq, longest) for seq in input_ids]

    def __call__(self, inp):
        """Collate a list of examples into a dict of batched tensors."""
        # Single pass over the batch, reading 'input_ids' before 'label'.
        pairs = [(example['input_ids'], example['label']) for example in inp]
        sequences = [pair[0] for pair in pairs]
        targets = [pair[1] for pair in pairs]
        padded = self._pad_self(sequences)
        return {
            'input_ids': torch.LongTensor(padded),
            'label': torch.LongTensor(targets),
        }

In [331]:
from torch.utils.data import DataLoader

In [332]:
# Iterate the dataset in batches of 5; BatchCollators pads each batch's
# input_ids with the tokenizer's pad id and returns the fields as tensors.
dl = DataLoader(dataset=mrpcc_data, collate_fn=BatchCollators(wt.pad_token_id), batch_size=5)

In [333]:
from torch.nn.modules import Module

In [334]:
from torch.nn import Embedding

In [335]:
from happy_s4.model.s4 import S4

In [336]:
from dataclasses import dataclass

In [337]:
@dataclass
class S4_GO_BRR_ARGS:
    """
    Configuration for the S4_GO_BRR model.

    Holds the embedding settings, the keyword arguments forwarded to the S4
    layer (see `get_s4_args`) and the pooling strategy.
    """
    # Used as `padding_idx` of the embedding layer.
    pad_token_id: int
    # Used as `num_embeddings` of the embedding layer.
    vocab_size: int
    # Embedding dimension; also forwarded to S4.
    d_model: int
    # The fields below are forwarded to S4 unchanged; their exact meaning is
    # defined by happy_s4.model.s4.S4.
    l_max: int
    channels: int
    bidirectional: bool = True
    # NOTE(review): annotated bool, but this notebook passes a dict of flags.
    trainable: bool = True
    lr: float = 0.001
    tie_state: bool = True
    # Plain bool default (a trailing comma here would make it the tuple (True,)).
    hurwitz: bool = True
    transposed: bool = True
    # Pooling strategy read by S4_GO_BRR.forward.
    pool: str = "last"

    def get_s4_args(self):
        """
        Return the subset of this configuration that is passed to S4.

        Returns
        -------
        dict
            Mapping from S4 keyword-argument name to its configured value.
        """
        s4_fields = ('d_model', 'l_max', 'channels', 'bidirectional', 'trainable', 'lr',
                     'tie_state', 'hurwitz', 'transposed')
        return {name: getattr(self, name) for name in s4_fields}

In [338]:
# Flags passed to S4 as its `trainable` argument (via S4_GO_BRR_ARGS).
# NOTE(review): the meaning of each key is defined by happy_s4's S4 — confirm there.
trainable = {"dt": True, "A": True, "P": True, "B": True}

In [339]:
pad_token_id = wt.pad_token_id

In [340]:
vocab_size = len(wt.vocab.get_itos())

In [341]:
# Model configuration. Embedding settings come from the fitted tokenizer;
# `transposed=False` overrides the dataclass default of True, and `pool`
# keeps its default ("last").
args = S4_GO_BRR_ARGS(pad_token_id=pad_token_id, 
          vocab_size=vocab_size,
          d_model=128, l_max=512, channels=1, bidirectional=True, trainable=trainable, lr=0.001, tie_state=True, hurwitz=True, transposed=False)

In [342]:
class S4_GO_BRR(Module):
    """
    Backbone: token embedding followed by an S4 layer and a pooling step.

    Parameters
    ----------
    args: S4_GO_BRR_ARGS
        Configuration; `get_s4_args()` supplies the S4 keyword arguments,
        `vocab_size`/`d_model`/`pad_token_id` configure the embedding and
        `pool` selects the pooling strategy.
    """

    def __init__(self, args: "S4_GO_BRR_ARGS"):
        super().__init__()
        self.s4 = S4(**args.get_s4_args())
        self.args = args
        self.embedding = Embedding(num_embeddings=args.vocab_size, embedding_dim=args.d_model, padding_idx=args.pad_token_id)

    def forward(self, input_ids):
        """
        Embed `input_ids`, run them through S4 and pool the result.

        Parameters
        ----------
        input_ids:
            Tensor of token ids. Assumed to be (batch, length) — TODO confirm
            against the S4 layer's expected layout.

        Returns
        -------
        Tensor
            The first element of the S4 output, indexed at the last position
            of its second axis.

        Raises
        ------
        ValueError
            If `args.pool` is not a supported strategy (only "last" is).
        """
        if self.args.pool != "last":
            # Fail with a clear message instead of an unbound local variable.
            raise ValueError(
                f"Unsupported pool strategy {self.args.pool!r}; only 'last' is implemented."
            )
        embedded = self.embedding(input_ids)
        s4_result = self.s4(embedded)
        # S4 returns a sequence; its first element is the per-position output.
        # NOTE(review): with right-padded batches the last position of shorter
        # sequences is a padding position — confirm that pooling it is intended.
        return s4_result[0][:, -1]
        

In [343]:
import torch.nn.functional as F

In [344]:
class S4_GO_BRR_Classification(Module):
    """
    Classification wrapper around the S4_GO_BRR backbone.

    Parameters
    ----------
    args: S4_GO_BRR_ARGS
        Configuration handed to the backbone.
    """

    def __init__(self, args: "S4_GO_BRR_ARGS"):
        super().__init__()
        self.backbone = S4_GO_BRR(args)
        self.args = args

    def forward(self, input_ids, labels=None):
        """
        Run the backbone and, when labels are given, compute the loss.

        Parameters
        ----------
        input_ids:
            Token ids passed straight to the backbone.
        labels: optional
            Target class indices. When omitted (e.g. at inference time) no
            loss is computed.

        Returns
        -------
        tuple
            `(loss, s4_out)` where `loss` is the cross-entropy between
            `s4_out` and `labels`, or None when `labels` is None.
        """
        s4_out = self.backbone(input_ids)
        # NOTE(review): no projection to the number of classes is applied here,
        # so the backbone's pooled output is used directly as the logits —
        # confirm this is intended.
        loss = None if labels is None else F.cross_entropy(s4_out, labels)
        return loss, s4_out

In [345]:
batch = next(iter(dl))

In [346]:
input_ids = batch['input_ids']

In [347]:
model = S4_GO_BRR(args)

In [348]:
model = S4_GO_BRR_Classification(args)

In [349]:
batch['label']

tensor([1, 0, 1, 0, 1])

In [351]:
loss, out = model(input_ids, batch['label'])

In [352]:
from torch.optim import Adam

In [353]:
optim = Adam(model.parameters())

In [354]:
loss.backward()

In [355]:
optim.step()

In [300]:
input_ids.shape

torch.Size([5, 55])

In [302]:
model(input_ids)

torch.Size([5, 128])