### Imports

In [74]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import xml.etree.ElementTree as ET

from torch.utils.data import DataLoader

from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
from torch.optim import AdamW


### Seed

In [75]:
# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable between runs. The call returns the default Generator object.
torch.manual_seed(seed=42)

<torch._C.Generator at 0x10bb3a030>

### Dataset

In [76]:
class EntailmentDataset(Dataset):
    """Tokenized sentence pairs with entailment labels, loaded from an XML file.

    The file is read as UTF-8 and parsed in one go. Every direct ``<pair>`` child
    of the root element contributes one example built from:

    * the text of its ``<t>`` and ``<h>`` child elements (presumably the text and
      the hypothesis of the entailment pair), tokenized together as a pair, and
    * its ``entailment`` attribute, mapped to an integer via ``LABEL_TO_ID``.

    All examples are tokenized eagerly in ``__init__`` and padded/truncated to
    ``max_length`` tokens.

    Args:
        file_path: Path of the XML file to load.
        tokenizer: Callable tokenizer accepting ``(text, text_pair, ...)`` and
            returning an object with ``input_ids`` and ``attention_mask`` tensors
            of shape ``(1, max_length)``; must also provide ``decode``.
        max_length: Length every encoded example is padded/truncated to.

    Raises:
        KeyError: If a pair has no ``entailment`` attribute or an unknown value.
        ValueError: If a pair lacks a ``<t>`` or ``<h>`` element.
    """

    # Mapping from the (upper-cased) ``entailment`` attribute to the class id.
    LABEL_TO_ID = {"NO": 0, "YES": 1, "UNKNOWN": 2}

    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attention_masks = []
        self.labels = []

        with open(file_path, 'r', encoding='utf-8') as file:
            xml_data = file.read()

        root = ET.fromstring(xml_data)

        for position, pair in enumerate(root.findall('pair')):
            # Only used so that error messages point at the offending pair.
            pair_ref = pair.get('id', f'#{position}')

            t = self._element_text(pair, 't', pair_ref)
            h = self._element_text(pair, 'h', pair_ref)
            label_id = self._label_id(pair.get('entailment'), pair_ref)

            encodings = self.tokenizer(t, h, max_length=max_length, padding='max_length',
                                       truncation=True, return_tensors="pt", add_special_tokens=True)

            # The tokenizer returns a batch of one; drop that leading dimension.
            self.input_ids.append(encodings.input_ids.squeeze(0))
            self.attention_masks.append(encodings.attention_mask.squeeze(0))
            self.labels.append(label_id)

    @staticmethod
    def _element_text(pair, tag, pair_ref):
        """Return the text of child ``tag``; an empty element yields ``''``."""
        element = pair.find(tag)
        if element is None:
            raise ValueError(f"pair {pair_ref!r} has no <{tag}> element")
        # ``Element.text`` is None for an empty element. Passing None to the
        # tokenizer would either fail or silently drop the second segment.
        return element.text or ''

    @classmethod
    def _label_id(cls, raw_label, pair_ref):
        """Map an ``entailment`` attribute value to its integer class id."""
        if raw_label is None:
            raise KeyError(f"pair {pair_ref!r} has no 'entailment' attribute")
        # Tolerate case and surrounding-whitespace variants such as "Yes" or " NO ".
        normalized = raw_label.strip().upper()
        if normalized not in cls.LABEL_TO_ID:
            raise KeyError(
                f"pair {pair_ref!r} has unknown entailment label {raw_label!r}; "
                f"expected one of {sorted(cls.LABEL_TO_ID)}"
            )
        return cls.LABEL_TO_ID[normalized]

    def __len__(self):
        """Number of examples loaded from the file."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one example as a dict of tensors keyed for a model forward call."""
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_masks[idx],
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

    def decode(self, idx):
        """Turn the stored token ids of example ``idx`` back into a string."""
        return self.tokenizer.decode(self.input_ids[idx])

In [77]:
# Load the tokenizer from the same checkpoint family as the model fine-tuned
# below (DistilBERT), so the token ids always match the model's vocabulary.
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
file_path = 'data.xml'
val_file_path = 'val_data.xml'
dataset = EntailmentDataset(file_path, tokenizer)
val_dataset = EntailmentDataset(val_file_path, tokenizer)

# Sanity-check a handful of encoded examples rather than dumping the whole
# dataset into the cell output.
num_preview = min(5, len(dataset))
for idx in range(num_preview):
    sample = dataset[idx]
    print("Decoded Text: ", dataset.decode(idx))
    print("Labels: ", sample['labels'].item())


Decoded Text:  [CLS] crude oil for april delivery traded at $ 37. 80 a barrel, down 28 cents [SEP] crude oil prices rose to $ 37. 80 per barrel [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
Labels:  0
Decoded Text:  [CLS] oracle had fought to keep the forms from being released [SEP] oracle released a confidential document [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

### Initialisation

In [78]:
# Pick the fastest available backend: CUDA first, then Apple's MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
# `torch.backends.mps` is absent from PyTorch builds that predate the MPS
# backend, so probe for the attribute before calling into it.
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: mps


In [79]:
# `dataset` and `val_dataset` were already parsed and tokenized in the Dataset
# section above with the same files and max_length (128), so wrap those objects
# instead of reading and tokenizing both XML files a second time.
# Training batches are reshuffled every epoch; validation order stays fixed.
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [80]:
# Three-way classifier (matching the three entailment labels) on top of the
# pretrained DistilBERT encoder, moved to the selected device.
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    num_labels=3,
).to(device)

# NOTE: the datasets above were tokenized before this point, so rebinding
# `tokenizer` here does not affect them; it only changes what later cells see.
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

optimizer = AdamW(params=model.parameters(), lr=5e-5)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Train

In [81]:
def train(epoch):
    """Run one optimisation pass over ``dataloader`` and report the mean loss.

    Relies on the module-level ``model``, ``dataloader``, ``optimizer`` and
    ``device``. The model is updated in place.

    Args:
        epoch: Zero-based epoch index; used only in the progress message,
            where it is shown one-based.

    Returns:
        float: Mean of the per-batch losses for this pass.

    Raises:
        ValueError: If ``dataloader`` contains no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail up front with a clear message instead of a ZeroDivisionError
        # when averaging at the end.
        raise ValueError("dataloader is empty; nothing to train on")

    model.train()
    total_loss = 0.0
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}

        optimizer.zero_grad()
        # The batch dict includes 'labels', so the model output carries the loss.
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / num_batches
    print(f'Epoch: {epoch+1}, Loss: {average_loss:.4f}')
    return average_loss

In [82]:
# Fine-tune for a fixed number of full passes over the training DataLoader.
# `train` prints the mean batch loss after each pass.
num_epochs = 10
for epoch in range(num_epochs):
    train(epoch)

Epoch: 1, Loss: 1.0477
Epoch: 2, Loss: 0.9571
Epoch: 3, Loss: 0.7972
Epoch: 4, Loss: 0.4888
Epoch: 5, Loss: 0.2458
Epoch: 6, Loss: 0.1291
Epoch: 7, Loss: 0.0937
Epoch: 8, Loss: 0.0534
Epoch: 9, Loss: 0.0345
Epoch: 10, Loss: 0.0295


### Test

In [83]:
# Switch off dropout etc. for evaluation.
model.eval()

def compute_accuracy(predictions, labels):
    """Return the fraction of positions where ``predictions`` equals ``labels``.

    Both arguments are tensors of the same shape; the result is a 0-dim float tensor.
    """
    return (predictions == labels).float().mean()

# Per-batch details are printed only for the first few batches to keep the
# output short, but every batch contributes to the overall accuracy.
max_batches_shown = 5
num_correct = 0
num_examples = 0
with torch.no_grad():
    for batch_idx, batch in enumerate(val_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        predictions = torch.argmax(outputs.logits, dim=-1)

        # Count correct examples rather than averaging batch accuracies, so a
        # smaller final batch does not skew the overall figure.
        num_correct += (predictions == batch['labels']).sum().item()
        num_examples += batch['labels'].numel()

        if batch_idx < max_batches_shown:
            accuracy = compute_accuracy(predictions, batch['labels'])
            print("Predictions:", predictions)
            print("Actual Labels:", batch['labels'])
            print(f"Accuracy: {accuracy.item():.4f}")

if num_examples:
    print(f"Overall validation accuracy: {num_correct / num_examples:.4f} "
          f"({num_correct}/{num_examples})")

Predictions: tensor([1, 2, 0, 1, 1, 0, 1, 1, 1, 2, 0, 1, 2, 2, 1, 2], device='mps:0')
Actual Labels: tensor([2, 1, 2, 2, 1, 1, 1, 1, 2, 2, 2, 1, 2, 1, 1, 1], device='mps:0')
Accuracy: 0.4375
Predictions: tensor([2, 2, 1, 0, 0, 1, 0, 1, 1, 2, 1, 2, 1, 1, 1, 0], device='mps:0')
Actual Labels: tensor([1, 2, 2, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, 0], device='mps:0')
Accuracy: 0.4375
Predictions: tensor([0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 2, 0, 2, 1], device='mps:0')
Actual Labels: tensor([1, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 1, 2, 1, 1], device='mps:0')
Accuracy: 0.1875
Predictions: tensor([1, 1, 1, 1, 1, 2, 1, 2, 2, 1, 0, 2, 2, 2, 1, 1], device='mps:0')
Actual Labels: tensor([1, 2, 1, 2, 2, 1, 0, 1, 1, 1, 2, 0, 1, 2, 1, 1], device='mps:0')
Accuracy: 0.3750
Predictions: tensor([2, 0, 2, 1, 1, 2, 1, 1, 1, 1, 2, 0, 2, 0, 1, 2], device='mps:0')
Actual Labels: tensor([1, 1, 0, 1, 1, 0, 1, 0, 1, 2, 2, 0, 1, 1, 1, 1], device='mps:0')
Accuracy: 0.4375
