In [1]:
import torch
from torchtext.datasets import SST2

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


train_data = SST2(split="train")
eval_data = SST2(split="dev")


In [2]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

sentences = []
labels = []

for text, label in train_data:
    sentences.append(text)
    labels.append(label)

tokenized_texts = [tokenizer.tokenize(sentence) for sentence in sentences]

MAX_LEN = 128
input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]
input_ids = torch.tensor(
    [ids[:MAX_LEN] + [0] * (MAX_LEN - len(ids)) for ids in input_ids])

attention_masks = torch.tensor(
    [[1 if token_id > 0 else 0 for token_id in ids] for ids in input_ids])

labels = torch.tensor(labels)

len(input_ids)

67349

In [3]:
eval_sentences = []
eval_labels = []

for text, label in eval_data:
    eval_sentences.append(text)
    eval_labels.append(label)

eval_tokenized_texts = [tokenizer.tokenize(
    sentence) for sentence in eval_sentences]

MAX_LEN = 128
eval_input_ids = [tokenizer.convert_tokens_to_ids(
    x) for x in eval_tokenized_texts]
eval_input_ids = torch.tensor(
    [ids[:MAX_LEN] + [0] * (MAX_LEN - len(ids)) for ids in eval_input_ids])

eval_attention_masks = torch.tensor(
    [[1 if token_id > 0 else 0 for token_id in ids] for ids in eval_input_ids])
eval_labels = torch.tensor(eval_labels)

len(eval_input_ids)

872

In [5]:
from torch.utils.data import TensorDataset, DataLoader

train_dataset = TensorDataset(input_ids, attention_masks, labels)
eval_dataset = TensorDataset(eval_input_ids, eval_attention_masks, eval_labels)

BATCH_SIZE = 16

train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [21]:
from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)
        self.init_weights()


    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, attention_mask):
        embedded = self.embedding(text)
        output, (hidden, cell) = self.lstm(embedded)
        masked_output = output * attention_mask.unsqueeze(-1)
        attention_weights = torch.softmax(masked_output, dim=1)
        attention_output = torch.sum(attention_weights * output, dim=1)
        output = self.fc(attention_output)
        return output.squeeze()

In [25]:
LR = 0.001
num_epochs = 10
EMBED_DIM = 512
HIDDEN_DIM = 256

model = TextClassificationModel(
    len(tokenizer), embed_dim=EMBED_DIM, hidden_dim=HIDDEN_DIM, num_classes=1)
model.to(DEVICE)

criterion = nn.BCEWithLogitsLoss().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [26]:
def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """

    # round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()  # convert into float for division
    acc = correct.sum() / len(correct)
    return acc

In [28]:
epoch_acc = []

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_acc = 0
    
    for index, (input_ids, attention_masks, labels) in enumerate(train_dataloader):
        input_ids = input_ids.to(DEVICE)
        attention_masks = attention_masks.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()

        output = model(input_ids, attention_masks)

        loss = criterion(output.squeeze(), labels.float())

        acc = binary_accuracy(output, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_acc += acc.item()

        print(f"Epoch {epoch+1}/{num_epochs} | Batch {index}/{len(train_dataloader)} | Loss: {loss.item():.4f} | Accuracy: {acc.item()*100:.2f}%")

    epoch_acc.append(f'Accuracy for epoch {epoch} ==> {total_acc / len(train_dataloader)}')

print(f'Loss: {total_loss / len(train_dataloader)}, Acc: {total_acc / len(train_dataloader)}')

for i in range(len(epoch_acc)):
    print(f'Acc for Epoch {i}: {epoch_acc[i]}')

Epoch 1/10 | Batch 0/4210 | Loss: 0.4562 | Accuracy: 87.50%
Epoch 1/10 | Batch 1/4210 | Loss: 0.0734 | Accuracy: 93.75%
Epoch 1/10 | Batch 2/4210 | Loss: 0.0680 | Accuracy: 100.00%
Epoch 1/10 | Batch 3/4210 | Loss: 0.2270 | Accuracy: 93.75%
Epoch 1/10 | Batch 4/4210 | Loss: 0.2760 | Accuracy: 87.50%
Epoch 1/10 | Batch 5/4210 | Loss: 0.3148 | Accuracy: 81.25%
Epoch 1/10 | Batch 6/4210 | Loss: 0.4668 | Accuracy: 81.25%
Epoch 1/10 | Batch 7/4210 | Loss: 0.1070 | Accuracy: 93.75%
Epoch 1/10 | Batch 8/4210 | Loss: 0.0679 | Accuracy: 100.00%
Epoch 1/10 | Batch 9/4210 | Loss: 0.1612 | Accuracy: 87.50%
Epoch 1/10 | Batch 10/4210 | Loss: 0.2954 | Accuracy: 87.50%
Epoch 1/10 | Batch 11/4210 | Loss: 0.1050 | Accuracy: 100.00%
Epoch 1/10 | Batch 12/4210 | Loss: 0.0805 | Accuracy: 93.75%
Epoch 1/10 | Batch 13/4210 | Loss: 0.2069 | Accuracy: 93.75%
Epoch 1/10 | Batch 14/4210 | Loss: 0.0600 | Accuracy: 100.00%
Epoch 1/10 | Batch 15/4210 | Loss: 0.1089 | Accuracy: 93.75%
Epoch 1/10 | Batch 16/4210 | L

In [38]:
data_iter = iter(eval_dataloader)
input_ids, attention_masks, labels = next(data_iter)

input_ids = input_ids.to(DEVICE)
attention_masks = attention_masks.to(DEVICE)
labels = labels.to(DEVICE)

output = model(input_ids, attention_masks)


# optimizer.zero_grad()

# sentiment_scores = output.mean(dim=1)

loss = criterion(output.squeeze(), labels.float())

In [48]:
model.eval()
corrects = 0
total = 0

with torch.no_grad():
    for input_ids, attention_masks, labels in eval_dataloader:

        input_ids = input_ids.to(DEVICE)
        attention_masks = attention_masks.to(DEVICE)
        labels = labels.to(DEVICE)

        output = model(input_ids, attention_masks)
        predicted_labels = (output > 0.5).long()

        corrects += (predicted_labels == labels).sum().item()
        total += len(labels)

accuracy = corrects / total

accuracy

0.7775229357798165

In [None]:
torch.save(model.state_dict(), './modelCheckpoint')