In [1]:
from pathlib import Path
import torch

import polars as pl

from strip_headers import strip_headers
import re

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, get_scheduler
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import evaluate
import torch.nn.functional as F


In [2]:
class GutenbergDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
        
    def __getitem__(self, idx):
        item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item
    
    def __len__(self):
        return len(self.labels)
    

In [3]:
def remove_people(people, text):
    new_text = text
    for person in people:
        new_text = re.sub(person, 'Person', new_text, flags=re.IGNORECASE)
    return new_text

In [4]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

data_dir = Path.cwd() / "data"
metadata = pl.read_csv(data_dir / "metadata.csv")



authors = ["Goethe, Johann Wolfgang von", "Schiller, Friedrich"]
author_mapping = {"Goethe, Johann Wolfgang von": 0,
                  "Schiller, Friedrich": 1}

metadata = metadata.filter(pl.col("language") == "['de']")


tokenizer = AutoTokenizer.from_pretrained("distilbert-base-german-cased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-german-cased", num_labels=2)


Some weights of the model checkpoint at distilbert-base-german-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-german-cased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', '

In [5]:
with open("people.txt", "r") as f:
    #removing \n 
    people = [person[:-1] for person in f.readlines()]


In [6]:
metadata.filter((pl.col("author") == "Schiller, Friedrich") & (pl.col("subjects").str.contains("Drama")))


id,title,author,authoryearofbirth,authoryearofdeath,language,downloads,subjects,type
str,str,str,i64,i64,str,i64,str,str
"""PG47804""","""Die Räuber: Ei...","""Schiller, Frie...",1759,1805,"""['de']""",364,"""{'Tragedies', ...",
"""PG6383""","""Die Jungfrau v...","""Schiller, Frie...",1759,1805,"""['de']""",36,"""{'Tragedies', ...",
"""PG6496""","""Die Braut von ...","""Schiller, Frie...",1759,1805,"""['de']""",28,"""{'Sicily (Ital...",
"""PG6498""","""Kabale und Lie...","""Schiller, Frie...",1759,1805,"""['de']""",66,"""{'Love -- Dram...",
"""PG6499""","""Die Verschwöru...","""Schiller, Frie...",1759,1805,"""['de']""",14,"""{'Fiéschi, Gia...",
"""PG6505""","""Turandot, Prin...","""Schiller, Frie...",1759,1805,"""['de']""",23,"""{'China -- Dra...",
"""PG6518""","""Wallensteins L...","""Schiller, Frie...",1759,1805,"""['de']""",137,"""{'Wallenstein,...",
"""PG6525""","""Die Piccolomin...","""Schiller, Frie...",1759,1805,"""['de']""",41,"""{""Thirty Years...",
"""PG6549""","""Wallensteins T...","""Schiller, Frie...",1759,1805,"""['de']""",40,"""{'Wallenstein,...",
"""PG7939""","""Die Huldigung ...","""Schiller, Frie...",1759,1805,"""['de']""",6,"""{'German poetr...",


In [7]:
metadata.filter((pl.col("author") == "Goethe, Johann Wolfgang von") & (pl.col("subjects").str.contains("Drama")))


id,title,author,authoryearofbirth,authoryearofdeath,language,downloads,subjects,type
str,str,str,i64,i64,str,i64,str,str
"""PG10353""","""Satyros oder D...","""Goethe, Johann...",1749,1832,"""['de']""",2,"""{'Drama'}""",
"""PG10425""","""Torquato Tasso...","""Goethe, Johann...",1749,1832,"""['de']""",41,"""{'Tasso, Torqu...",
"""PG10426""","""Die natürliche...","""Goethe, Johann...",1749,1832,"""['de']""",8,"""{'Drama'}""",
"""PG10428""","""Die Aufgeregte...","""Goethe, Johann...",1749,1832,"""['de']""",6,"""{'Europe -- So...",
"""PG2054""","""Iphigenie auf ...","""Goethe, Johann...",1749,1832,"""['de']""",104,"""{'Iphigenia (M...",
"""PG21000""","""Faust: Eine Tr...","""Goethe, Johann...",1749,1832,"""['de']""",959,"""{'Legends -- G...",
"""PG2146""","""Egmont""","""Goethe, Johann...",1749,1832,"""['de']""",42,"""{'Tragedies', ...",
"""PG2229""","""Faust: Der Tra...","""Goethe, Johann...",1749,1832,"""['de']""",670,"""{'Legends -- G...",
"""PG2230""","""Faust: Der Tra...","""Goethe, Johann...",1749,1832,"""['de']""",279,"""{'Legends -- G...",
"""PG2321""","""Götz von Berli...","""Goethe, Johann...",1749,1832,"""['de']""",26,"""{'Berlichingen...",


In [8]:
#train test split, preprocessing, tokenization and splitting into blocks in no particular order

train_encodings = {'input_ids'      : torch.tensor([], requires_grad=False, dtype=torch.long),
                   'token_type_ids' : torch.tensor([], requires_grad=False, dtype=torch.long),
                   'attention_mask' : torch.tensor([], requires_grad=False, dtype=torch.long),
                   }
train_labels = []

test_encodings = {'input_ids'      : torch.tensor([], requires_grad=False, dtype=torch.long),
                  'token_type_ids' : torch.tensor([], requires_grad=False, dtype=torch.long),
                  'attention_mask' : torch.tensor([], requires_grad=False, dtype=torch.long),
                  }
test_labels = []

val_encodings = {'input_ids'      : torch.tensor([], requires_grad=False, dtype=torch.long),
                 'token_type_ids' : torch.tensor([], requires_grad=False, dtype=torch.long),
                 'attention_mask' : torch.tensor([], requires_grad=False, dtype=torch.long),
                 }
val_labels = []


train_encodings ={'input_ids'      : [],
                  'token_type_ids' : [],
                  'attention_mask' : [],
                  }
train_labels = []

test_encodings ={'input_ids'      : [],
                  'token_type_ids' : [],
                  'attention_mask' : [],
                  }
test_labels = []

val_encodings ={'input_ids'      : [],
                  'token_type_ids' : [],
                  'attention_mask' : [],
                  }
val_labels = []


train_ids = []
test_ids = []
val_ids = []

removed_chars = "\r\n\t.:~()[]{}"


for author in authors:
    
    #select texts by the authors from the whole gutenberg corpus
    author_ids = metadata.filter((pl.col("author") == author) & (pl.col("subjects").str.contains("Drama")))["id"].to_list()
    
    #splitting in such a way that the model has never seen any parts of the play before
    #otherwise it would probably just learn to recognize the names of the characters
    #it will still ... 
    #I will just have very bad test loss
    
    #TODO: implement name removal
    #https://stackoverflow.com/questions/53534376/removing-names-from-noun-chunks-in-spacy
    
    train_ids += author_ids[:-2]
    test_ids += [author_ids[-2]]
    val_ids += [author_ids[-1]] 
    
    
    for doc_id in author_ids:
        
        if doc_id in ["PG21000", "PG2229"]:
            continue
        
        file_path = data_dir / "raw" / (doc_id + "_raw.txt")
        
        try:
            with open(file_path, "r") as in_f:
                raw_text = in_f.read()    
        except FileNotFoundError:
            print(f"Warning file not found{file_path}")
            continue
            
        
        #script for removing the gutenberg project headers        
        text = strip_headers(raw_text)         
        
        #removing all people from the people file
        if doc_id in train_ids:
            text = remove_people(people, text)
           
        
        #more custom header and footer stripping
        
        #this book has a longer appendix
        if doc_id =='PG47804':
            
            a = re.search(r'\b(Fußnoten)\b', text)
            text = text[1000:-a.start()]
            
        # for the rest we just strip another 1000 chars
        else: 
            text = text[1000:-1000]

        text.strip(removed_chars) 
        #encoding the data first 
        #since I probably want to use the full 512 tokens without doing any truncating or padding
        encoding = tokenizer(text, return_tensors="pt")
        
        # might be inefficient but fine for now
        # input ids here are token ids not doc ids ...

        encodings = {'input_ids'      : torch.split(encoding['input_ids'], 32, dim=1)[:-1],
                     #'token_type_ids' : torch.split(encoding['token_type_ids'], 32, dim=1)[:-1],
                     'attention_mask' : torch.split(encoding['attention_mask'], 32, dim=1)[:-1],
                    }
                     
        
        #not dry it's pretty damp
        if doc_id in train_ids:
            train_encodings = {key : list(encodings[key]) + train_encodings[key] for key in encodings}
            train_labels += [author_mapping[author] for _ in range(len(encodings['input_ids']))]
            # train_labels += [author_mapping[author].clone().detach() for _ in range(len(encoding['input_ids']))]
        elif doc_id in test_ids:
            test_encodings = {key : list(encodings[key]) + test_encodings[key] for key in encodings}
            test_labels += [author_mapping[author] for _ in range(len(encodings['input_ids']))]
            # test_labels += [author_mapping[author].clone().detach() for _ in range(len(encoding['input_ids']))]
        elif doc_id in val_ids:
            val_encodings = {key : list(encodings[key]) + val_encodings[key] for key in encodings}
            val_labels += [author_mapping[author] for _ in range(len(encodings['input_ids']))]
            # val_labels += [author_mapping[author].clone().detach() for _ in range(len(encoding['input_ids']))]


Token indices sequence length is longer than the specified maximum sequence length for this model (4563 > 512). Running this sequence through the model will result in indexing errors


In [9]:
# train_encodings["input_ids"]      = torch.split(torch.cat(train_encodings['input_ids'],dim=0),16, dim=0)
# train_encodings["token_type_ids"] = torch.split(torch.cat(train_encodings['token_type_ids'],dim=0),16, dim=0)
# train_encodings["attention_mask"] = torch.split(torch.cat(train_encodings['attention_mask'],dim=0),16, dim=0)
# train_labels = torch.split(torch.cat(train_labels, dim=0), 16, dim=0)


# test_encodings["input_ids"]      = torch.split(torch.cat(test_encodings['input_ids'],dim=0),16, dim=0)
# test_encodings["token_type_ids"] = torch.split(torch.cat(test_encodings['token_type_ids'],dim=0),16, dim=0)
# test_encodings["attention_mask"] = torch.split(torch.cat(test_encodings['attention_mask'],dim=0),16, dim=0)
# test_labels = torch.split(torch.cat(test_labels, dim=0), 16, dim=0)




# val_encodings["input_ids"]      = torch.split(torch.cat(val_encodings['input_ids'],dim=0),16, dim=0)
# val_encodings["token_type_ids"] = torch.split(torch.cat(val_encodings['token_type_ids'],dim=0),16, dim=0)
# val_encodings["attention_mask"] = torch.split(torch.cat(val_encodings['attention_mask'],dim=0),16, dim=0)
# val_labels = torch.split(torch.cat(val_labels, dim=0), 16, dim=0)


In [10]:
for split in [train_encodings, test_encodings, val_encodings]:
    for key in split.keys():
        split[key] = [ torch.squeeze(seq) for seq in split[key]]

In [11]:
train_dataset = GutenbergDataset(train_encodings, train_labels)
test_dataset = GutenbergDataset(test_encodings, test_labels)
val_dataset = GutenbergDataset(val_encodings, val_labels)



In [12]:

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
eval_loader = DataLoader(val_dataset, batch_size=8)




In [16]:
train_metric = evaluate.load("accuracy")
val_metric = evaluate.load("accuracy")
num_epochs = 2
num_training_steps = num_epochs * len(train_loader)
progress_bar = tqdm(range(num_training_steps))
optim = AdamW(model.parameters(), lr=7e-4)
lr_scheduler = get_scheduler(name="linear", optimizer=optim, num_warmup_steps=500, num_training_steps=num_training_steps)

  0%|          | 0/4188 [00:00<?, ?it/s]



In [None]:

progress_bar = tqdm(range(num_training_steps), desc='Bar descr',)

model.to(device)
losses = []
train_acc = 0
val_acc = 0
for epoch in range(num_epochs):
    model.train()
    counter = 0
    for batch in train_loader:
        counter += 1
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device).squeeze()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        loss.backward()
        optim.step()
        losses.append(loss.item())
        if counter % 100 == 0:
            progress_bar.set_description(f"Train Loss = {sum(losses) / len(losses):.4f} | Train Acc = {train_acc:.4f} Val Acc = {val_acc:.4f}")
        lr_scheduler.step()
        progress_bar.update(1)
        
        
    model.eval()
    counter = 0
    
    for batch in train_loader:
        counter += 1

        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():

            outputs = model(**batch)

        logits = outputs.logits
        

        predictions = torch.argmax(logits, dim=-1)
        

        train_metric.add_batch(predictions=predictions, references=batch["labels"])
        
        if counter == 200:
            break

    
    for batch in eval_loader:

        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():

            outputs = model(**batch)

        logits = outputs.logits
        

        predictions = torch.argmax(logits, dim=-1)
        

        val_metric.add_batch(predictions=predictions, references=batch["labels"])

    train_acc = train_metric.compute()["accuracy"]
    val_acc = val_metric.compute()["accuracy"]
    progress_bar.set_description(f"Train Loss = {sum(losses) / len(losses):.4f} | Train Acc = {train_acc:.4f} Val Acc = {val_acc:.4f}")


Bar descr:   0%|          | 0/4188 [00:00<?, ?it/s]

In [None]:
model.save_pretrained("model.save")
