In [22]:
import torch as t
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
from einops import rearrange
from torch.nn import functional as F
from tqdm import tqdm
import random
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional
from einops import rearrange
import wandb
from fancy_einsum import einsum

In [23]:
# Short alias for print, used when reporting the average generated length.
p = print
# Maps each opening bracket to its matching closing bracket.
BDICT = {
    '(': ')',
    '[': ']',
    '{': '}',
}

In [24]:
def generate_bracket_string(maxsize, rng=None):
    '''
    Generate a random, correctly nested bracket string of at most `maxsize` characters.

    The string is built in int(maxsize * 0.9) steps. Each step either opens a
    random bracket from BDICT or closes the innermost open one (50/50, except
    that an opener is forced when nothing is open). Brackets still open after
    the last step are then closed. Candidates longer than `maxsize` are thrown
    away and regenerated.

    maxsize: maximum allowed length of the result (must be >= 0).
    rng: optional source of randomness exposing `random()` and `choice()`
        (e.g. a `random.Random` instance). Defaults to the `random` module, so
        a single `random.seed(...)` call makes the output reproducible.

    Returns: the generated string (empty when int(maxsize * 0.9) == 0).
    Raises: ValueError if `maxsize` is negative, since no string can fit.
    '''
    if maxsize < 0:
        raise ValueError(f'maxsize must be non-negative, got {maxsize}')
    rng = random if rng is None else rng
    openers = list(BDICT.keys())
    steps = int(maxsize * 0.9)
    # Retry in a loop rather than recursively so an unlucky streak cannot hit the recursion limit.
    while True:
        pieces = []
        pending_closers = []  # closers still owed, innermost last
        for _ in range(steps):
            if not pending_closers or rng.random() < 0.5:
                opener = rng.choice(openers)
                pending_closers.append(BDICT[opener])
                pieces.append(opener)
            else:
                pieces.append(pending_closers.pop())
        # Close whatever is still open, innermost first.
        pieces.extend(reversed(pending_closers))
        if len(pieces) <= maxsize:
            return ''.join(pieces)


In [25]:
# Sanity check: empirical mean length of the generated strings for these settings.
# Over-long candidates are regenerated, so no sample is longer than BLEN and the
# mean stays below it (a recorded run printed about 465.8).
BLEN = 500
ITERS = 1000
p(f'Average len: {sum([len(generate_bracket_string(BLEN)) for _ in range(ITERS)])/ITERS}')


Average len: 465.812


In [26]:
def isValid(s: str) -> bool:
    '''
    Return True if `s` is a correctly nested sequence of (), [] and {} brackets.

    The empty string is valid. Any character that is not one of the six
    brackets makes the string invalid, as does a closer with no matching
    opener or an opener that is never closed.
    '''
    closer_to_opener = {')': '(', ']': '[', '}': '{'}
    openers = set(closer_to_opener.values())
    stack = []
    for char in s:
        if char in openers:
            stack.append(char)
        elif char in closer_to_opener:
            # A closer must match the most recent unclosed opener.
            if not stack or stack.pop() != closer_to_opener[char]:
                return False
        else:
            # Not a bracket at all.
            return False
    # Leftover openers mean the string is unbalanced.
    return not stack

In [27]:
def make_invalid_bracket_string(size):
    '''
    Return a bracket string of at most `size` characters that fails `isValid`.

    A valid string is generated and then corrupted by between 1 and
    max(1, int(size * 0.1)) replacement passes. Each pass picks two different
    bracket characters and rewrites every occurrence of the first as the
    second. If the result is still valid (for example, the chosen bracket did
    not occur), the whole process starts over.

    size: maximum length passed to `generate_bracket_string`; must be >= 2,
        because shorter requests yield the empty string, which cannot be
        corrupted.

    Raises: ValueError if `size` < 2.
    '''
    if size < 2:
        raise ValueError(f'size must be at least 2 to build an invalid string, got {size}')
    bracketlist = list('()[]{}')
    max_passes = max(1, int(size * 0.1))
    # Loop instead of recursing so repeated failures cannot exhaust the stack.
    while True:
        bracket_string = generate_bracket_string(size)
        corrupt_size = random.randint(1, max_passes)
        for _ in range(corrupt_size):
            # Sampling without replacement guarantees old != new, so a pass is never a no-op by construction.
            old, new = random.sample(bracketlist, 2)
            bracket_string = bracket_string.replace(old, new)
        if not isValid(bracket_string):
            return bracket_string

# Every generated sample must be rejected by the validator.
assert not any(isValid(make_invalid_bracket_string(10)) for _ in range(1000))

In [28]:
# create dataset with bracket strings
# and labels for each bracket

# tokenizer for brackets
class BracketTokenizer:
    '''
    Character-level tokenizer for bracket strings.

    vocab: dict mapping each token string to its integer id. It must contain
        the special tokens '[PAD]', '[CLS]', '[SEP]', '[MASK]' and '[UNK]'.
    maxlen: fixed length of every encoded sequence, including [CLS] and [SEP].
    '''

    def __init__(self, vocab, maxlen):
        self.maxlen = maxlen
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.pad_token_id = self.vocab['[PAD]']
        self.cls_token_id = self.vocab['[CLS]']
        self.sep_token_id = self.vocab['[SEP]']
        self.mask_token_id = self.vocab['[MASK]']
        self.unk_token_id = self.vocab['[UNK]']
        self.vocab_inv = {idx: token for token, idx in vocab.items()}
        # If several tokens share an id, these special tokens win on decode.
        for special in ('[PAD]', '[CLS]', '[SEP]', '[MASK]'):
            self.vocab_inv[self.vocab[special]] = special

    def convert_tokens_to_ids(self, tokens):
        '''Map each token to its id; tokens missing from the vocab map to [UNK].'''
        return [self.vocab.get(token, self.unk_token_id) for token in tokens]

    def convert_ids_to_tokens(self, ids):
        '''Map each id back to its token string.'''
        return [self.vocab_inv[id] for id in ids]

    def __call__(self, text):
        return self.encode(text)

    def encode(self, text):
        '''Return [CLS] + ids(text) + [SEP], right-padded with [PAD] to `maxlen`.'''
        assert len(text) + 2 <= self.maxlen, (
            f'text of length {len(text)} does not fit in maxlen={self.maxlen} with [CLS] and [SEP]'
        )
        ids = self.convert_tokens_to_ids(text)
        ids = [self.cls_token_id] + ids + [self.sep_token_id] # add cls and sep tokens
        ids = ids + [self.pad_token_id] * (self.maxlen - len(ids)) # pad to max len
        return ids

    def decode(self, ids):
        '''
        Turn a list (or 1-D tensor) of ids back into a string.

        Padding is dropped everywhere; a leading [CLS] and a trailing [SEP]
        are stripped if present. Empty or all-padding input decodes to ''.
        '''
        if isinstance(ids, t.Tensor):
            ids = ids.tolist()
        ids = [id for id in ids if id != self.pad_token_id]
        # Guard the index lookups so empty sequences do not raise IndexError.
        if ids and ids[0] == self.cls_token_id:
            ids = ids[1:]
        if ids and ids[-1] == self.sep_token_id:
            ids = ids[:-1]
        return ''.join(self.convert_ids_to_tokens(ids))

# Full vocabulary: the three bracket pairs followed by the special tokens.
vocab = {'(': 0, ')': 1, '[': 2, ']': 3, '{': 4, '}': 5, '[PAD]': 6, '[CLS]': 7, '[SEP]': 8, '[MASK]': 9, '[UNK]': 10}
# Parentheses-only vocabulary; not referenced by the other cells shown here.
simplevocab = {'(': 0, ')': 1, '[PAD]': 2, '[CLS]': 3, '[SEP]': 4, '[MASK]': 5, '[UNK]': 6}
# maxlen=16 leaves room for 14 bracket characters plus [CLS] and [SEP].
tokenizer = BracketTokenizer(vocab, maxlen=16)


In [29]:
# Round-trip check: each generated string must encode to exactly maxlen ids
# and decode back to the original text.
bracketss = [generate_bracket_string(tokenizer.maxlen-2) for _ in range(10)]
for brackets in bracketss:
    print(brackets)
    assert len(tokenizer(brackets)) == tokenizer.maxlen
    assert tokenizer.decode(tokenizer(brackets)) == brackets

()[()]{{}}{}
()[]{}[][][]
(([]))[][[[]]]
{}[]{}{(){}}
{{{{}}{[]}}}
(()(()))()()
(({[]}))()[()]
[](([[()]]()))
(([{}{}()])())
()[]()[]{[]()}


In [30]:
class BracketDataset(TensorDataset):
    '''
    Dataset of tokenized bracket strings labelled 1 (valid) or 0 (invalid).

    size: total number of examples.
    tokenizer: BracketTokenizer used to encode each string to `maxlen` ids.
    validfrac: fraction of examples that are valid strings; the rest are
        produced by `make_invalid_bracket_string`.
    seed: seed for the private RNG that picks example lengths and the shuffle
        order (default 42, the previously hard-coded value).

    NOTE: the string generators are called without this RNG, so the bracket
    contents are reproducible only if their own random sources are seeded
    separately.

    The tensors are also kept on `self.train` as (input_ids, labels).
    '''

    def __init__(self, size, tokenizer: BracketTokenizer, validfrac=0.7, seed=42):
        self.tokenizer = tokenizer
        self.size = size
        self.validfrac = validfrac
        self.rng = random.Random(seed)
        self.train = self._make_dataset()
        super().__init__(*self.train)

    def _random_length(self):
        '''Draw an even target length between 2 and maxlen - 2 (or maxlen - 3 when maxlen is odd).'''
        return 2 * self.rng.randint(2, self.tokenizer.maxlen // 2) - 2

    def _make_dataset(self):
        '''Build, shuffle and tensorize the examples; returns (input_ids, labels).'''
        validsize = int(self.size * self.validfrac)
        invalidsize = self.size - validsize
        valid_bracket_strings = [
            self.tokenizer(generate_bracket_string(self._random_length()))
            for _ in range(validsize)
        ]
        invalid_bracket_strings = [
            self.tokenizer(make_invalid_bracket_string(self._random_length()))
            for _ in range(invalidsize)
        ]

        bracket_strings = valid_bracket_strings + invalid_bracket_strings
        bracket_labels = [1] * validsize + [0] * invalidsize
        # Shuffle via an index permutation; this also works for an empty dataset,
        # where unpacking zip(*[]) would fail.
        order = list(range(len(bracket_strings)))
        self.rng.shuffle(order)
        bracket_strings = [bracket_strings[i] for i in order]
        bracket_labels = [bracket_labels[i] for i in order]
        # to tensor; the explicit shape keeps the (n, maxlen) layout when n == 0
        bracket_strings = t.tensor(bracket_strings, dtype=t.long).reshape(len(order), self.tokenizer.maxlen)
        bracket_labels = t.tensor(bracket_labels, dtype=t.long)
        return bracket_strings, bracket_labels

In [31]:
# Build the train and test splits; both use the default validfrac of 0.7.
trainset = BracketDataset(size=4096, tokenizer=tokenizer)
testset = BracketDataset(size=512, tokenizer=tokenizer)

In [32]:
# Label sanity check: label 1 must coincide with isValid on the decoded string.
for x,y in trainset:
    assert isValid(tokenizer.decode(x)) == y.item()

In [33]:
# Print the first 51 decoded training examples with their labels (stops after index 50).
for i, (x, y) in enumerate(trainset):
    print(tokenizer.decode(x), y.item())
    if i == 50:
        break

([][]))})} 0
{()}{()()} 1
{} 1
[][(())] 1
[{}]({}()) 1
))]()]()]{}) 0
]}[][]]} 0
()[]{}[{}] 1
(){[][ 0
({}) 1
[](){}{}{{}{}} 1
{[)}()[)[) 0
(({[]((}(((( 0
{}()[] 1
[[]](){)() 0
[{}]()({{}}) 1
()()() 1
[{}]{} 1
({[{}]})[[]] 1
[]()[]{{}[]} 1
()()[]([]) 1
())})[]}() 0
({{[]}})() 1
()[] 1
[()[ 0
[][()] 1
[][{}]}) 0
[](()) 1
{}{}({}){} 1
{}{{[]}}() 1
[]({}){}({}) 1
(}[] 0
([({})()][]) 1
()[()] 1
[]{}{[]()} 1
[] 1
{()( 0
])])]){}[]{])} 0
({})(()[[]()]) 1
(){] 0
{([[]])}[] 1
[[][(())]] 1
({}) 1
[]{}{}{}{} 1
{}[[}} 0
(()) 1
[] 1
() 1
({[] 0
()[] 1
{} 1


In [34]:
# Batch both splits in groups of 64; note that the test loader is shuffled too.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=True)

In [35]:
# Peek at a single batch to confirm the input shape is (batch_size, maxlen).
for (x, y) in trainloader:
    print(x.shape)
    break

torch.Size([64, 16])


In [36]:
#@title Transformer Modules
@dataclass(frozen=True)
class TransformerConfig:
    '''Hyperparameters shared by the transformer modules defined in this notebook.'''

    num_layers: int  # number of stacked transformer blocks
    num_heads: int  # attention heads per block; hidden_size must be divisible by it
    vocab_size: int  # number of rows in the token embedding
    hidden_size: int # also embedding dim or d_model
    max_seq_len: int = 5000  # rows in the position embedding; also sizes the classifier input
    dropout: float = 0.1  # dropout probability
    # Intended LayerNorm epsilon.
    # NOTE(review): confirm the modules actually pass this to nn.LayerNorm.
    layer_norm_epsilon: float = 1e-05
    # No type annotation, so `dataclass` treats this as a plain class attribute
    # rather than a field: it cannot be set through the constructor.
    device = t.device('cuda' if t.cuda.is_available() else 'cpu')

In [37]:
class MultiheadAttention(nn.Module):
    '''Multi-head self-attention with separate Q/K/V projections and an output projection.'''

    def __init__(self, config: TransformerConfig):
        super().__init__()
        width = config.hidden_size
        self.num_heads = config.num_heads
        self.W_Q = nn.Linear(width, width)
        self.W_K = nn.Linear(width, width)
        self.W_V = nn.Linear(width, width)
        self.W_O = nn.Linear(width, width)

    def forward(self, x: t.Tensor, additive_attention_mask: Optional[t.Tensor] = None) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        additive_attention_mask: optional tensor added to the attention scores before the softmax

        Return: shape (batch, seq, hidden_size)
        '''
        queries = self.W_Q(x)
        keys = self.W_K(x)
        values = self.W_V(x)
        mixed = self.multihead_masked_attention(queries, keys, values, self.num_heads, additive_attention_mask)
        return self.W_O(mixed)

    def multihead_masked_attention(self, Q: t.Tensor, K: t.Tensor, V: t.Tensor, n_heads: int, additive_attention_mask: Optional[t.Tensor]):
        '''
        Scaled dot-product attention computed independently for each head.

        Q: shape (b, s1, e)
        K: shape (b, s2, e)
        V: shape (b, s2, e)

        e = nheads * h, where b = batch, s = seq_len and h = per-head size

        Return: shape (b s e)
        '''
        for projection in (Q, K, V):
            assert projection.shape[-1] % n_heads == 0
        assert K.shape[-1] == V.shape[-1]

        split_heads = 'b s (nheads h) -> b nheads s h'
        Q = rearrange(Q, split_heads, nheads=n_heads)
        K = rearrange(K, split_heads, nheads=n_heads)
        V = rearrange(V, split_heads, nheads=n_heads)

        headsize = Q.shape[-1]
        scores = einsum('b nheads sk h, b nheads sq h -> b nheads sq sk', K, Q) / (headsize ** 0.5)
        if additive_attention_mask is not None:
            # the mask, e.g. (batch, 1, 1, sk), broadcasts over heads and query positions
            scores += additive_attention_mask
        weights = scores.softmax(dim=-1)
        per_head = einsum('b nheads s1 s2, b nheads s2 c -> b nheads s1 c', weights, V)
        return rearrange(per_head, 'b nheads s c -> b s (nheads c)')

class BERTMLP(nn.Module):
    '''Position-wise feed-forward block: Linear (4x expansion) -> GELU -> Linear -> Dropout.'''

    def __init__(self, config: TransformerConfig):
        super().__init__()
        width = config.hidden_size
        inner_width = width * 4
        # Submodule names are kept exactly, because they appear in parameter names.
        layers = OrderedDict()
        layers['linear1'] = nn.Linear(width, inner_width)
        layers['GELU'] = nn.GELU()
        layers['linear2'] = nn.Linear(inner_width, width)
        layers['dropout'] = nn.Dropout(config.dropout)
        self.mlp = nn.Sequential(layers)

    def forward(self, x: t.Tensor):
        '''Apply the feed-forward stack to `x`; the last dimension must be hidden_size.'''
        out = self.mlp(x)
        return out

class BERTBlock(nn.Module):
    '''
    One post-LayerNorm transformer block: self-attention and an MLP, each
    followed by a residual connection and a LayerNorm.
    '''

    def __init__(self, config):
        super().__init__()
        self.attention = MultiheadAttention(config)
        # Honour the configured epsilon instead of silently using the LayerNorm default.
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BERTMLP(config)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(self, x: t.Tensor, additive_attention_mask: Optional[t.Tensor] = None) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        additive_attention_mask: shape (batch, nheads=1, seqQ=1, seqK)

        Return: shape (batch, seq, hidden_size)
        '''
        h1 = self.ln1(self.attention(x, additive_attention_mask) + x) # TODO chain this
        h2 = self.ln2(self.mlp(h1) + h1)
        return h2

def make_additive_attention_mask(one_zero_attention_mask: t.Tensor, big_negative_number: float = -10000) -> t.Tensor:
    '''
    Convert a 1/0 keep-mask into a mask that is added to attention scores.

    one_zero_attention_mask:
        shape (batch, seq)
        1 marks a real token, 0 marks a padding token.

    big_negative_number:
        Value added at padding positions. It should be negative enough that
        exp(big_negative_number) is 0.0 for the floating point precision used.

    Out:
        shape (batch, nheads=1, seqQ=1, seqK)
        0 where attention is allowed, big_negative_number where it is not.
    '''
    is_padding = 1 - one_zero_attention_mask
    penalties = is_padding * big_negative_number
    return rearrange(penalties, 'batch seq -> batch 1 1 seq')

class BertCommon(nn.Module):
    '''
    Shared encoder body: token + position + token-type embeddings, a LayerNorm
    and dropout, followed by `config.num_layers` BERTBlocks.
    '''

    def __init__(self, config: 'TransformerConfig'):
        super().__init__()
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.hidden_size)
        self.tokentype_emb = nn.Embedding(2, config.hidden_size)
        # Honour the configured epsilon instead of silently using the LayerNorm default.
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout1 = nn.Dropout(config.dropout)
        self.bertblocks = nn.ModuleList([BERTBlock(config) for i in range(config.num_layers)])

    def forward(
        self,
        input_ids: t.Tensor,
        one_zero_attention_mask: Optional[t.Tensor] = None,
        token_type_ids: Optional[t.Tensor] = None,
    ) -> t.Tensor:
        '''
        input_ids: (batch, seq) - the token ids
        one_zero_attention_mask: (batch, seq) - 1 for real tokens, 0 for padding; converted with
            `make_additive_attention_mask` and passed to every block. None disables masking.
        token_type_ids: (batch, seq) - segment ids (0 or 1) for the token type embedding; defaults to all zeros.

        Return: (batch, seq, hidden_size)
        '''
        token_embedding = self.token_emb(input_ids) # (b, seq_len, emb)
        batch, seq_len = input_ids.shape
        positional_embedding = self.pos_emb(t.arange(seq_len, device=input_ids.device)) # (seq_len, emb)
        # Compare against None explicitly: the truth value of a multi-element tensor raises a RuntimeError.
        if token_type_ids is None:
            token_type_ids = t.zeros_like(input_ids)
        token_type_embedding = self.tokentype_emb(token_type_ids) # (b, seq_len, emb)
        x = self.dropout1(self.ln1(token_embedding + positional_embedding + token_type_embedding))
        mask = make_additive_attention_mask(one_zero_attention_mask) if one_zero_attention_mask is not None else None
        for block in self.bertblocks:
            x = block(x, mask)
        return x

class BertLanguageModel(nn.Module):
    '''
    BertCommon encoder followed by a two-way linear classifier over the
    flattened sequence of hidden states.

    The classifier input width is hidden_size * max_seq_len, so inputs must be
    exactly max_seq_len tokens long. `gelu`, `ln` and `unembed_bias` are
    created here but are not used by `forward`.
    '''

    def __init__(self, config: TransformerConfig):
        super().__init__()
        width = config.hidden_size
        self.bertcommon = BertCommon(config)
        self.linear = nn.Linear(width * config.max_seq_len, 2)
        self.gelu = nn.GELU()
        self.ln = nn.LayerNorm(config.hidden_size)
        xavier = 1 / (config.vocab_size ** 0.5)
        # Normal samples scaled by 2 * xavier and shifted by -xavier.
        self.unembed_bias = nn.parameter.Parameter(t.randn(config.vocab_size) * 2 * xavier - xavier)

    def forward(
        self,
        input_ids: t.Tensor,
        one_zero_attention_mask: Optional[t.Tensor] = None,
        token_type_ids: Optional[t.Tensor] = None,
    ) -> t.Tensor:
        '''
        input_ids: (batch, seq) - the token ids; seq must equal max_seq_len
        one_zero_attention_mask: (batch, seq) - forwarded to the encoder, which turns it into an additive attention mask
        token_type_ids: (batch, seq) - forwarded to the encoder's token type embedding

        Return: (batch, 2) classification logits
        '''
        hidden = self.bertcommon(input_ids, one_zero_attention_mask, token_type_ids)
        # Concatenate every position's hidden state into one vector per example.
        flattened = rearrange(hidden, 'batch seq emb -> batch (seq emb)')
        logits = self.linear(flattened)
        return logits


# Small model sized for the bracket task: max_seq_len is tied to the tokenizer's
# fixed encoding length because the classifier head flattens the whole sequence.
bertconfig = TransformerConfig(
    num_layers = 2,
    num_heads = 4,
    vocab_size = len(vocab),
    hidden_size = 100,
    max_seq_len = tokenizer.maxlen,
    dropout = 0.1,
    layer_norm_epsilon = 1e-12
)

# Freshly initialised classifier; it is trained in place by the training cell.
my_bert = BertLanguageModel(bertconfig)

In [38]:
# Use the GPU when available; read as a global by the training and inference cells.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

In [39]:
def make_optimizer(model: nn.Module, config_dict: dict) -> t.optim.AdamW:
    '''
    Build an AdamW optimizer with two parameter groups:

    - The first group holds the weight matrix of every nn.Linear layer and uses
      config_dict['weight_decay'].
    - The second group holds all other parameters (biases, embeddings,
      LayerNorm parameters, ...) and uses a weight decay of 0.

    config_dict keys read: 'lr', 'weight_decay' and, optionally, 'eps'
    (defaults to 1e-8). Both groups share lr and eps; other keys are ignored.
    '''
    # Identify Linear weights by module type, not by name, so the split matches the contract for any model.
    linear_weight_ids = {
        id(module.weight) for module in model.modules() if isinstance(module, nn.Linear)
    }
    decay_params = []
    no_decay_params = []
    for param in model.parameters():
        if id(param) in linear_weight_ids:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    # Only the intended options go into the groups: spreading config_dict here
    # would override the explicit weight_decay of 0.
    param_groups = [
        {'params': decay_params, 'weight_decay': config_dict['weight_decay']},
        {'params': no_decay_params, 'weight_decay': 0},
    ]
    return t.optim.AdamW(param_groups, lr=config_dict['lr'], eps=config_dict.get('eps', 1e-8))

In [40]:
def lr_for_step(step: int, max_step: int, max_lr: float, warmup_step_frac: float):
    '''
    Learning rate for `step` under linear warmup followed by linear decay.

    Background: the BERT authors warmed up from an unspecified value and shape
    to a maximum of 1e-4 over the first 10,000 of 1 million steps, then decayed
    linearly to an unspecified value. Their optimization.py shows AdamW with
    epsilon 1e-6 and a linear warmup.

    Here the initial and final learning rates are both 1/10th of `max_lr`:

    - for step < int(max_step * warmup_step_frac), the rate rises linearly from
      0.1 * max_lr to max_lr;
    - afterwards it falls linearly, reaching 0.1 * max_lr at `max_step`, and
      stays there for any later step.

    If the warmup covers all of `max_step`, steps at or past the end of warmup
    return `max_lr`.
    '''
    floor_lr = 0.1 * max_lr
    span = max_lr - floor_lr
    warmup_steps = int(max_step * warmup_step_frac)
    if step < warmup_steps:
        # warmup_steps > step here, so the division is safe whenever step >= 0
        return floor_lr + span * max(step, 0) / warmup_steps
    decay_steps = max_step - warmup_steps
    if decay_steps <= 0:
        # No room left for a decay phase; avoid dividing by zero.
        return max_lr
    # Clamp so steps beyond max_step hold the final rate instead of going below it.
    remaining = min(max((max_step - step) / decay_steps, 0.0), 1.0)
    return floor_lr + span * remaining

In [41]:
def accuracy(preds: t.Tensor, targets: t.Tensor) -> t.Tensor:
    '''
    Fraction of rows whose arg-max over the last dimension equals the target.

    preds: (..., num_classes) scores or logits.
    targets: class indices with the shape of `preds` minus its last dimension.

    Returns a 0-dim float tensor; call `.item()` to get a Python float.
    '''
    predicted_classes = preds.argmax(dim=-1)
    return (predicted_classes == targets).float().mean()


In [42]:
import os
import tqdm
import requests
def flat(x: t.Tensor) -> t.Tensor:
    """Merge the leading batch and sequence axes into one, keeping any trailing axes."""
    merge_pattern = "b s ... -> (b s) ..."
    merged = rearrange(x, merge_pattern)
    return merged

def _print_probe_logits(model, probe_strings) -> None:
    '''Print the model's logits for each probe string, with dropout disabled.'''
    was_training = model.training
    model.eval()  # probes should not be perturbed by dropout
    try:
        with t.inference_mode():
            for s in probe_strings:
                ids = t.tensor(tokenizer(s)).unsqueeze(0).to(device)
                probe_mask = (ids != tokenizer.pad_token_id).float()
                probe_logits = model(ids, probe_mask, token_type_ids=None)
                print(f"{s} {probe_logits}")
    finally:
        # Restore whichever mode the caller was in.
        model.train(was_training)


def _notify_done(run_name: str) -> None:
    '''Best-effort push notification; a network failure must not fail a finished run.'''
    try:
        requests.post(
            "https://ntfy.sh/arena-brackets",
            data=f"Done training {run_name} 🎉".encode(encoding='utf-8'),
            timeout=10,
        )
    except requests.RequestException as err:
        print(f"Could not send completion notification: {err}")


def bert_mlm_pretrain(model: BertLanguageModel, config_dict: dict, train_loader: DataLoader) -> None:
    '''
    Train `model` as a valid/invalid bracket-string classifier.

    Despite the name, no masked-language-modelling objective is used: each
    batch is classified with a cross-entropy loss against its 0/1 labels, and
    padding positions are hidden from attention.

    model: moved to the global `device` and left in training mode.
    config_dict: must contain 'epochs', plus whatever `make_optimizer` reads.
        Learning-rate scheduling, gradient clipping and wandb logging are not
        enabled.
    train_loader: yields (input_ids, labels) batches.

    Side effects: prints the logits for a few fixed probe strings on every
    100th batch of each epoch (starting with batch 0), and posts a completion
    notification to ntfy.sh when training ends.
    '''
    model.train()
    model.to(device)
    opt = make_optimizer(model, config_dict)
    run_name = wandb.run.name if wandb.run else 'bert-brackets'
    probe_strings = ["((((", "))((", "()()", "))", "()"]

    for epoch in range(config_dict['epochs']):
        progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")
        for n_batch, (batch, target) in enumerate(progress_bar):
            batch = batch.to(device)
            target = target.to(device).long()
            # 1 for real tokens, 0 for padding
            mask = (batch != tokenizer.pad_token_id).float()

            opt.zero_grad()
            logits = model(batch, mask, token_type_ids=None)
            loss = F.cross_entropy(logits, target)
            loss.backward()
            opt.step()

            progress_bar.set_description(f"Loss {loss.item():.4f}")
            if n_batch % 100 == 0:
                _print_probe_logits(model, probe_strings)

    _notify_done(run_name)

# Training hyperparameters. lr, weight_decay and eps are optimizer settings
# consumed by make_optimizer; epochs is read by the training loop.
# NOTE(review): batch_size here (256) differs from the DataLoader's batch_size=64
# and does not appear to affect batching; mask_token_id, warmup_step_frac and
# max_grad_norm likewise do not appear to influence the loop - confirm before relying on them.
config_dict = dict(
    lr=0.001,
    epochs=5,
    batch_size=64*4,
    weight_decay=0.01,
    mask_token_id=tokenizer.mask_token_id,
    warmup_step_frac=0.01,
    eps=1e-06,
    max_grad_norm=None,
)

# Trains my_bert in place on the training loader.
bert_mlm_pretrain(my_bert, config_dict, trainloader)

Loss 1.6069:   0%|          | 0/64 [00:00<?, ?it/s]

(((( tensor([[-1.9332,  2.3217]])
))(( tensor([[-2.2118,  2.5835]])
()() tensor([[-2.1859,  2.3808]])
)) tensor([[-1.9444,  2.4374]])
() tensor([[-2.0990,  2.7319]])


Loss 0.5983: 100%|██████████| 64/64 [00:01<00:00, 44.95it/s]
Loss 0.5518:   6%|▋         | 4/64 [00:00<00:01, 33.99it/s]

(((( tensor([[ 0.1983, -0.0083]])
))(( tensor([[ 0.7751, -0.6380]])
()() tensor([[-0.4450,  0.3301]])
)) tensor([[ 0.5242, -0.5141]])
() tensor([[-0.1597, -0.0525]])


Loss 0.1941: 100%|██████████| 64/64 [00:01<00:00, 44.58it/s]
Loss 0.1922:  16%|█▌        | 10/64 [00:00<00:01, 45.27it/s]

(((( tensor([[ 3.9170, -2.9281]])
))(( tensor([[ 2.9779, -1.4640]])
()() tensor([[-1.2837,  1.5322]])
)) tensor([[ 4.1218, -2.5658]])
() tensor([[-1.5344,  1.5978]])


Loss 0.0101: 100%|██████████| 64/64 [00:01<00:00, 45.15it/s]
Loss 0.0125:  16%|█▌        | 10/64 [00:00<00:01, 44.91it/s]

(((( tensor([[ 5.0696, -4.8935]])
))(( tensor([[-2.5230,  3.0130]])
()() tensor([[-5.5734,  5.5346]])
)) tensor([[ 3.5120, -2.6335]])
() tensor([[-6.1492,  6.3551]])


Loss 0.0005: 100%|██████████| 64/64 [00:01<00:00, 44.98it/s]
Loss 0.0079:  16%|█▌        | 10/64 [00:00<00:01, 44.79it/s]

(((( tensor([[ 5.2483, -5.4290]])
))(( tensor([[-1.8967,  2.8358]])
()() tensor([[-5.7074,  5.7196]])
)) tensor([[ 5.8669, -5.6101]])
() tensor([[-6.7756,  6.5710]])


Loss 0.0010: 100%|██████████| 64/64 [00:01<00:00, 43.34it/s]


In [43]:
# Switch to eval mode first so dropout is disabled for this probe; the training
# function puts the model in train mode.
my_bert.eval()
with t.inference_mode():
    s = ")("
    tens = t.tensor(tokenizer(s)).unsqueeze(0).to(device)
    print(tens)
    print(my_bert(tens, (tens != tokenizer.pad_token_id).float(), token_type_ids=None))
    # TODO: enable held-out evaluation on the test loader:
    # for batch, target in testloader:
    #     batch = batch.to(device)
    #     target = target.to(device)
    #     mask = (batch != tokenizer.pad_token_id).float()
    #     mask.requires_grad = False
    #     logits = my_bert(batch, mask, token_type_ids=None)
    #     print(accuracy(logits, target))


tensor([[7, 1, 0, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]])
tensor([[-3.4914,  3.9164]])
