<a href="https://colab.research.google.com/github/graviraja/100-Days-of-NLP/blob/applications%2Fclassification/applications/classification/natural_language_inference/NLI%20with%20BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install sh
!pip install nlp
!pip install transformers
!pip install pytorch_lightning



In [2]:
import sh

import nlp
import transformers
import torch as th
import pytorch_lightning as pl

In [3]:
# Pick the GPU when CUDA is usable; otherwise stay on the CPU.
if th.cuda.is_available():
    device = th.device('cuda')
else:
    device = th.device('cpu')
device

device(type='cuda')

In [4]:
# Run configuration: every value a reader might want to tune lives in this cell.
debug = False                      # True: load only `batch_size` examples per split and enable fast_dev_run
seed = 42                          # seed for torch's random number generators
epochs = 2                         # passed to the Trainer as max_epochs
batch_size = 16                    # used by both the train and validation dataloaders
lr = 1e-2                          # SGD learning rate
momentum = 0.9                     # SGD momentum
model_type = 'bert-base-uncased'   # pretrained weights / vocabulary to load
model_hidden = 768                 # input width of the classification layer
num_classes = 3                    # output width of the classification layer
seq_length = 100                   # max_length handed to the tokenizer
percent = 2                        # share (in %) of each SNLI split to load when debug is False
dropout = 0.3                      # dropout probability applied to the pooled BERT output

# Seed torch so that weight initialisation, dropout and dataloader shuffling
# are repeatable across Restart & Run All.
th.manual_seed(seed)

# Start from an empty log directory on every run.
sh.rm('-r', '-f', 'logs')
sh.mkdir('logs')



In [5]:
class SNLIClassifier(pl.LightningModule):
    """BERT-based classifier for SNLI premise/hypothesis pairs.

    The pooled output of a pretrained ``BertModel`` is passed through dropout
    and one linear layer that produces ``num_classes`` logits. Dataset
    preparation, the dataloaders and the optimizer are defined on the module
    too, so ``trainer.fit(model)`` needs no further arguments.
    """

    def __init__(self):
        super().__init__()
        self.model = transformers.BertModel.from_pretrained(model_type)
        self.classifier = th.nn.Linear(model_hidden, num_classes)
        # Keep per-example losses: validation concatenates them over all
        # batches before averaging; training averages them per batch.
        self.loss = th.nn.CrossEntropyLoss(reduction='none')

        self.dropout = th.nn.Dropout(dropout)

    def prepare_data(self, split="train"):
        """Load, filter and tokenize the SNLI train and validation splits.

        Stores the results on ``self.train_ds`` and ``self.val_ds``.

        Args:
            split: Unused; both splits are always prepared. Kept so existing
                callers that pass it keep working.
        """
        tokenizer = transformers.BertTokenizer.from_pretrained(model_type)

        def _has_gold_label(example):
            # SNLI contains examples labelled -1 (no gold label). They are
            # dropped instead of being relabelled as class 0.
            return example['label'] != -1

        def _tokenize(x):
            # batch_encode_plus takes ONE positional batch argument. Sentence
            # pairs must therefore be given as (premise, hypothesis) tuples;
            # a second positional list would be read as `add_special_tokens`
            # and the hypotheses would never be encoded.
            pairs = list(zip(x['premise'], x['hypothesis']))
            encoded = tokenizer.batch_encode_plus(
                pairs,
                max_length=seq_length,
                pad_to_max_length=True,
                verbose=False)
            x['input_ids'] = encoded['input_ids']
            x['token_type_ids'] = encoded['token_type_ids']
            x['output'] = list(x['label'])
            return x

        def _prepare_ds(split):
            # Debug runs load `batch_size` examples, otherwise `percent`% of the split.
            ds = nlp.load_dataset('snli', split=f'{split}[:{batch_size if debug else f"{percent}%"}]')
            ds = ds.filter(_has_gold_label)
            ds = ds.map(_tokenize, batched=True)
            ds.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'output'])
            return ds

        self.train_ds, self.val_ds = map(_prepare_ds, ('train', 'validation'))

    def forward(self, input_ids, token_type_ids):
        """Return classification logits for a batch of encoded sentence pairs.

        Args:
            input_ids: Token ids produced by the tokenizer in ``prepare_data``.
            token_type_ids: Segment ids produced by the same tokenizer call.

        Returns:
            The output of the linear classifier applied to BERT's pooled output.
        """
        # Attend only to non-zero token ids.
        # NOTE(review): assumes the tokenizer pads with id 0 -- confirm if `model_type` changes.
        mask = (input_ids != 0).float()
        outputs = self.model(
            input_ids,
            attention_mask=mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

    def training_step(self, batch, batch_idx):
        """Compute the mean cross-entropy loss of one training batch and log it."""
        logits = self.forward(batch['input_ids'], batch['token_type_ids'])
        loss = self.loss(logits, batch['output']).mean()
        return {'loss': loss, 'log': {'train_loss': loss}}

    def validation_step(self, batch, batch_idx):
        """Return per-example loss and correctness (1.0 / 0.0) for one validation batch."""
        logits = self.forward(batch['input_ids'], batch['token_type_ids'])
        loss = self.loss(logits, batch['output'])
        acc = (logits.argmax(-1) == batch['output']).float()
        return {'loss': loss, 'acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-example validation losses and accuracies over all batches."""
        loss = th.cat([o['loss'] for o in outputs], 0).mean()
        acc = th.cat([o['acc'] for o in outputs], 0).mean()
        out = {'val_loss': loss, 'val_acc': acc}
        return {**out, 'log': out}

    def train_dataloader(self):
        """Shuffled training loader; an incomplete final batch is dropped."""
        return th.utils.data.DataLoader(
            self.train_ds,
            batch_size=batch_size,
            drop_last=True,
            shuffle=True,
        )

    def val_dataloader(self):
        """Validation loader in dataset order; every example is kept."""
        return th.utils.data.DataLoader(
            self.val_ds,
            batch_size=batch_size,
            drop_last=False,
            shuffle=False,
        )

    def configure_optimizers(self):
        """Optimize all parameters (BERT and classifier) with SGD plus momentum."""
        return th.optim.SGD(
            self.parameters(),
            lr=lr,
            momentum=momentum,
        )

In [6]:
# Build the LightningModule first; the Trainer receives it in fit().
model = SNLIClassifier()

# TensorBoard logger rooted at logs/, run name 'snli', fixed version 0.
tb_logger = pl.loggers.TensorBoardLogger('logs/', name='snli', version=0)

trainer = pl.Trainer(
    logger=tb_logger,
    default_root_dir='logs',
    max_epochs=epochs,
    fast_dev_run=debug,
    # int(bool) gives 1 GPU when CUDA is usable and 0 otherwise.
    gpus=int(th.cuda.is_available()),
)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
100%|██████████| 12/12 [00:03<00:00,  3.17it/s]
100%|██████████| 1/1 [00:00<00:00, 12.73it/s]

  | Name       | Type             | Params
------------------------------------------------
0 | model      | BertModel        | 109 M 
1 | classifier | Linear           | 2 K   
2 | loss       | CrossEntropyLoss | 0     
3 | dropout    | Dropout          | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [8]:
# List the checkpoint directory of the 'snli' / version_0 run with file sizes.
!ls -lah logs/snli/version_0/checkpoints

total 836M
drwxr-xr-x 2 root root 4.0K Jul  2 18:27  .
drwxr-xr-x 3 root root 4.0K Jul  2 18:16  ..
-rw-r--r-- 1 root root 836M Jul  2 18:27 'epoch=1.ckpt'


In [10]:
# Checkpoint file shown in the directory listing above.
CKPT_PATH = 'logs/snli/version_0/checkpoints/epoch=1.ckpt'
# map_location moves the stored tensors onto the device chosen earlier in the
# notebook, so a checkpoint saved on GPU also loads on a CPU-only machine.
checkpoint = th.load(CKPT_PATH, map_location=device)