In [102]:
import torch
from typing import Union, List

class CharacterLevelTokenizer:
    def __init__(self, max_len: int) -> None:
        self.max_len = max_len
        self.vocab = {
            '<PAD>': 0,
            '<MASK>': 1,
        }
        offset = len(self.vocab)
        for k, v in zip('abcdefghijklmnopqrstuvwxyz', range(offset, 27+offset)):
            self.vocab[k] = v
        
    def tokenize(self, word: str, labels: List[str]=None, return_pt=True) -> Union[torch.Tensor, List[int]]:
        word = map(lambda x: x if x != '_' else '<MASK>', list(word))

        res = []
        for c in word:
            res.append(self.vocab[c])
        
        res += [self.vocab['<PAD>']] * (self.max_len - len(res))
        
        if labels:
            labels = list(map(lambda x: self.vocab.get(x) if x else -100, labels))
            labels = labels + [-100] * (self.max_len - len(labels))
            labels = torch.tensor(labels).long()
            res = torch.tensor(res).long()
            return res, labels
        
        if return_pt:
            return torch.tensor(res).long()
        
        return res


In [103]:
tokenizer = CharacterLevelTokenizer(max_len=16)

In [104]:
import math
import random
from typing import Tuple, List

def random_character_masker(word: str) -> Tuple[str, List[str]]:
    unique_chars = set(word)
    num_to_mask = random.randint(1, len(unique_chars))
    chars_to_mask = random.sample(unique_chars, num_to_mask)
    labels = [None for _ in range(len(word))]
    for c in chars_to_mask:
        for i, cc in enumerate(word):
            if c == cc:
                labels[i] = c
        word = word.replace(c, '_')
    return word, labels

In [105]:
word='shivam'
for i in range(20):
    print(i, random_character_masker(word))

0 ('shi_am', [None, None, None, 'v', None, None])
1 ('_h___m', ['s', None, 'i', 'v', 'a', None])
2 ('shi_a_', [None, None, None, 'v', None, 'm'])
3 ('____a_', ['s', 'h', 'i', 'v', None, 'm'])
4 ('___v__', ['s', 'h', 'i', None, 'a', 'm'])
5 ('______', ['s', 'h', 'i', 'v', 'a', 'm'])
6 ('___v_m', ['s', 'h', 'i', None, 'a', None])
7 ('s__v__', [None, 'h', 'i', None, 'a', 'm'])
8 ('s__v_m', [None, 'h', 'i', None, 'a', None])
9 ('______', ['s', 'h', 'i', 'v', 'a', 'm'])
10 ('shi_a_', [None, None, None, 'v', None, 'm'])
11 ('sh____', [None, None, 'i', 'v', 'a', 'm'])
12 ('____am', ['s', 'h', 'i', 'v', None, None])
13 ('s_____', [None, 'h', 'i', 'v', 'a', 'm'])
14 ('__i_a_', ['s', 'h', None, 'v', None, 'm'])
15 ('_h____', ['s', None, 'i', 'v', 'a', 'm'])
16 ('shi___', [None, None, None, 'v', 'a', 'm'])
17 ('______', ['s', 'h', 'i', 'v', 'a', 'm'])
18 ('shi_a_', [None, None, None, 'v', None, 'm'])
19 ('_h____', ['s', None, 'i', 'v', 'a', 'm'])


In [106]:
word='shivam'
for i in range(20):
    masked_word, label = random_character_masker(word)
    print(i, tokenizer.tokenize(masked_word, label))

0 (tensor([20,  1, 10, 23,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]), tensor([-100,    9, -100, -100, -100,   14, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100]))
1 (tensor([20,  9, 10,  1,  2, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]), tensor([-100, -100, -100,   23, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100]))
2 (tensor([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([  20,    9,   10,   23,    2,   14, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100]))
3 (tensor([ 1,  1, 10,  1,  1, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]), tensor([  20,    9, -100,   23,    2, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100]))
4 (tensor([ 1,  9, 10,  1,  1, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]), tensor([  20, -100, -100,   23,    2, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100]))
5 (tensor([ 1,  1,  1, 23,  1,  1,  0,  0,  0,  0,  0,  0,  0

In [182]:
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

class CustomDataset(Dataset):
    def __init__(self, word_file: str, split: str, max_len: int):
        with open(word_file, 'r') as f:
            lines = f.readlines()

        lines = list(map(lambda x: x.replace('\n', ''), lines))
        self.split = split
        train, test = train_test_split(lines, test_size=0.2, random_state=42)
        del lines

        if split == 'train':
            self.words = train
        else:
            val, test = train_test_split(test, test_size=0.5, random_state=42)
            if split == 'val':
                self.words = val
            elif split == 'test':
                self.words = test
            else:
                raise ValueError(f'Split should be in train/val/test but got {split}.')
        
        self.tokenizer = CharacterLevelTokenizer(max_len=max_len)
        

    def __len__(self):
        return len(self.words)

    def __getitem__(self, index):
        word = self.words[index]
        masked_word, label = random_character_masker(word)
        masked_word, label = self.tokenizer.tokenize(masked_word, label)
        attention_mask = torch.ones_like(masked_word)
        idx = masked_word == self.tokenizer.vocab['<PAD>']
        attention_mask[idx] = 0
        return {
            'input_ids': masked_word,
            'labels': label,
            'attention_mask': attention_mask
        }

In [183]:
ds = CustomDataset(
    word_file='words_250000_train.txt',
    split='train',
    max_len=64
)
next(iter(ds))

{'input_ids': tensor([ 6, 25,  9,  1,  1,  1,  1,  1, 21,  1,  1,  8,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 'labels': tensor([-100, -100, -100,   10,   13,    2,   19,    2, -100,   10,   15, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}

In [184]:
for i in range(3):
    print(next(iter(ds)))

{'input_ids': tensor([ 1,  1,  1,  1, 13,  2, 19,  2,  1,  1,  1,  1,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0]), 'labels': tensor([   6,   25,    9,   10, -100, -100, -100, -100,   21,   10,   15,    8,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}
{'input_ids': te

In [185]:
for split in ['train', 'val', 'test']:
    ds = CustomDataset(
        word_file='words_250000_train.txt',
        split=split,
        max_len=64
    )
    print(len(ds))

181840
22730
22730


In [186]:
import pytorch_lightning as pl
from transformers import BertModel, BertConfig
from torch.utils.data import DataLoader

In [192]:
import pdb

class BertClassifier(pl.LightningModule):
    def __init__(self, config, num_classes):
        super().__init__()
        self.bert = BertModel(config)
        self.classifier = torch.nn.Linear(config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs['last_hidden_state']
        # print(pooled_output.shape)
        logits = self.classifier(pooled_output)
        return logits
    
    def loss_fn(self, logits, labels):
        idx = labels != -100
        # pdb.set_trace()
        return torch.nn.functional.cross_entropy(logits[idx], labels[idx])

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def training_step(self, batch, batch_idx):
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = self.loss_fn(logits, batch['labels'])
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = self.loss_fn(logits, batch['labels'])
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = torch.nn.functional.cross_entropy(outputs, batch['labels'])
        self.log('test_loss', loss)

num_classes = 29  # Replace with the actual number of classes
max_len = 64
# Load the BERT configuration
config = BertConfig(
    vocab_size=num_classes,  # Replace with the actual vocab size
    hidden_size=64,
    num_hidden_layers=6,
    num_attention_heads=4,
    intermediate_size=64*2,
    max_position_embeddings=max_len,
)

# Instantiate the BERT-based classifier
bert_classifier = BertClassifier(config, num_classes)

# Create train, validation, and test datasets
train_dataset = CustomDataset(
    'words_250000_train.txt', 
    split='train', 
    max_len=max_len
)
val_dataset = CustomDataset(
    'words_250000_train.txt', 
    split='val', 
    max_len=max_len
)
test_dataset = CustomDataset(
    'words_250000_train.txt', 
    split='test', 
    max_len=max_len
)

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# Create a PyTorch Lightning Trainer
trainer = pl.Trainer(
    max_epochs=1,
    devices=1, 
    accelerator="cpu",
)

# Train the model using the Trainer
trainer.fit(bert_classifier, train_dataloader, val_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name       | Type      | Params
-----------------------------------------
0 | bert       | BertModel | 211 K 
1 | classifier | Linear    | 1.9 K 
-----------------------------------------
213 K     Trainable params
0         Non-trainable params
213 K     Total params
0.852     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 5683/5683 [10:08<00:00,  9.34it/s, v_num=8]      

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 5683/5683 [10:08<00:00,  9.34it/s, v_num=8]


In [162]:
batch = next(iter(train_dataloader))

In [163]:
logits = bert_classifier(
    input_ids=batch['input_ids'], 
    attention_mask=batch['attention_mask']
)

torch.Size([32, 64, 64])


In [165]:
logits.shape

torch.Size([32, 64, 27])

In [170]:

batch['label'].unsqueeze(-1).shape

torch.Size([32, 64, 1])

In [178]:
idx = batch['label'] != -100
loss = torch.nn.functional.cross_entropy(logits[idx], batch['label'][idx])

torch.Size([174, 27])
torch.Size([174])
