In [1]:
import pandas as pd
import numpy as np
import torch
from transformers import BertTokenizer, BertModel
from torch import nn
from torch.optim import Adam
from tqdm import tqdm

In [2]:
df = pd.read_csv('../data/gutenberg_paragraphs.csv')
df.dropna(inplace=True)
df.reset_index(drop=True, inplace=True)

#Stay with only 10% of the min count of paragraphs per author
min_count = df.groupby('Author').count().min()[0]
df = df.groupby('Author').head(int(min_count*0.10))
print(df.groupby('Author').count())


                                  Paragraph
Author                                     
Alcott, Louisa May                      103
Austen, Jane                            103
Brontë, Charlotte                       103
Christie, Agatha                        103
Dickens, Charles                        103
Dostoyevsky, Fyodor                     103
Doyle, Arthur Conan                     103
Dumas, Alexandre                        103
Hugo, Victor                            103
Marcus Aurelius, Emperor of Rome        103
Nietzsche, Friedrich Wilhelm            103
Poe, Edgar Allan                        103
Shakespeare, William                    103
Twain, Mark                             103
Verne, Jules                            103
Wells, H. G. (Herbert George)           103
Wilde, Oscar                            103


In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

authors = df['Author'].unique()
author2idx = {author: idx for idx, author in enumerate(authors)}
idx2author = {idx: author for idx, author in enumerate(authors)}

class Dataset(torch.utils.data.Dataset):

    def __init__(self, df):
        
        self.labels = [author2idx[author] for author in df['Author']]
        print(self.labels)
        self.texts = [ tokenizer(paragraph, padding='max_length', max_length = 512, truncation=True,
                                return_tensors="pt") for paragraph in df['Paragraph']]

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):

        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        return batch_texts, batch_y

In [4]:
#https://towardsdatascience.com/text-classification-with-bert-in-pytorch-887965e5820f
class BertClassifier(nn.Module):

    def __init__(self, dropout=0.5):

        super(BertClassifier, self).__init__()

        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, len(authors))
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):

        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)

        return final_layer

In [5]:
def train(model, train_data, val_data, learning_rate, epochs):

    train, val = Dataset(train_data), Dataset(val_data)

    train_dataloader = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size=2)

    use_cuda = torch.cuda.is_available()
    #use_cuda = False
    device = torch.device("cuda" if use_cuda else "cpu")

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr= learning_rate)

    if use_cuda:

            model = model.cuda()
            criterion = criterion.cuda()

    for epoch_num in range(epochs):

            total_acc_train = 0
            total_loss_train = 0

            for train_input, train_label in tqdm(train_dataloader):

                train_label = train_label.to(device)
                mask = train_input['attention_mask'].to(device)
                input_id = train_input['input_ids'].squeeze(1).to(device)

                output = model(input_id, mask)
                
                batch_loss = criterion(output, train_label.long())
                total_loss_train += batch_loss.item()
                
                acc = (output.argmax(dim=1) == train_label).sum().item()
                total_acc_train += acc

                model.zero_grad()
                batch_loss.backward()
                optimizer.step()
            
            total_acc_val = 0
            total_loss_val = 0

            with torch.no_grad():

                for val_input, val_label in val_dataloader:

                    val_label = val_label.to(device)
                    mask = val_input['attention_mask'].to(device)
                    input_id = val_input['input_ids'].squeeze(1).to(device)

                    output = model(input_id, mask)

                    batch_loss = criterion(output, val_label.long())
                    total_loss_val += batch_loss.item()
                    
                    acc = (output.argmax(dim=1) == val_label).sum().item()
                    total_acc_val += acc
            
            print(
                f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_data): .3f} | Train Accuracy: {total_acc_train / len(train_data): .3f} | Val Loss: {total_loss_val / len(val_data): .3f} | Val Accuracy: {total_acc_val / len(val_data): .3f}')
                  

In [6]:
def evaluate(model, test_data):

    test = Dataset(test_data)

    test_dataloader = torch.utils.data.DataLoader(test, batch_size=2)

    use_cuda = torch.cuda.is_available()
    #use_cuda = False
    device = torch.device("cuda" if use_cuda else "cpu")

    if use_cuda:

        model = model.cuda()

    total_acc_test = 0
    with torch.no_grad():

        for test_input, test_label in test_dataloader:

              test_label = test_label.to(device)
              mask = test_input['attention_mask'].to(device)
              input_id = test_input['input_ids'].squeeze(1).to(device)

              output = model(input_id, mask)
              acc = (output.argmax(dim=1) == test_label).sum().item()
              total_acc_test += acc
    
    print(f'Test Accuracy: {total_acc_test / len(test_data): .3f}')

In [7]:
np.random.seed(112)
df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42), 
                                     [int(.8*len(df)), int(.9*len(df))])

print(len(df_train),len(df_val), len(df_test))

1400 175 176


In [8]:
EPOCHS = 5
model = BertClassifier()
LR = 1e-6
              
train(model, df_train, df_val, LR, EPOCHS)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[7, 2, 4, 2, 14, 15, 15, 0, 11, 15, 7, 5, 7, 0, 12, 12, 5, 7, 10, 11, 3, 7, 8, 2, 0, 9, 9, 15, 5, 9, 10, 2, 16, 4, 11, 14, 5, 5, 10, 3, 0, 3, 3, 1, 10, 1, 11, 9, 0, 6, 14, 13, 7, 8, 16, 1, 10, 15, 15, 3, 3, 6, 2, 15, 1, 2, 7, 12, 16, 7, 8, 2, 4, 13, 11, 2, 11, 12, 16, 10, 14, 16, 9, 7, 8, 14, 4, 13, 9, 11, 15, 10, 11, 11, 12, 15, 1, 16, 4, 16, 10, 9, 14, 3, 0, 13, 9, 1, 10, 7, 3, 7, 15, 3, 3, 4, 3, 5, 5, 7, 10, 6, 6, 2, 14, 6, 15, 6, 13, 11, 15, 0, 7, 2, 15, 14, 2, 6, 10, 6, 10, 10, 6, 13, 6, 1, 0, 14, 3, 8, 9, 3, 14, 10, 14, 16, 1, 0, 14, 6, 9, 15, 5, 13, 16, 14, 6, 2, 7, 14, 6, 4, 7, 10, 8, 13, 3, 7, 16, 12, 11, 8, 3, 11, 5, 2, 8, 0, 7, 4, 16, 4, 11, 14, 15, 2, 8, 13, 9, 6, 0, 11, 6, 0, 15, 15, 4, 7, 14, 1, 4, 14, 1, 16, 5, 7, 5, 11, 6, 4, 14, 12, 13, 15, 0, 14, 7, 10, 12, 6, 10, 15, 3, 3, 2, 16, 14, 10, 4, 5, 2, 8, 15, 5, 3, 3, 12, 3, 7, 16, 12, 11, 7, 16, 10, 9, 0, 10, 6, 4, 15, 12, 9, 0, 14, 2, 6, 9, 5, 9, 0, 14, 0, 2, 3, 5, 8, 10, 8, 3, 2, 13, 4, 6, 16, 16, 12, 9, 11, 16, 12, 3, 

 17%|█▋        | 116/700 [00:17<01:27,  6.67it/s]


KeyboardInterrupt: 

In [None]:
evaluate(model, df_test)

In [None]:
#save model
torch.save(model.state_dict(), 'model2.pt')