In [1]:
from pathlib import Path
import torch

import polars as pl

from strip_headers import strip_headers
import re

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, get_scheduler
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import evaluate

import spacy

In [2]:
class GutenbergDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs tokenizer fields with one label per example.

    ``encodings`` maps a field name to an indexable collection of tensors and
    ``labels`` is an indexable collection holding one label per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The number of labels defines the size of the dataset.
        return len(self.labels)

    def __getitem__(self, idx):
        # Hand out detached copies so callers never alias the stored tensors.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = values[idx].clone().detach()
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
    

In [3]:
def remove_people(people, text):
    """Replace every occurrence of the given names in ``text`` with 'Person'.

    Args:
        people: iterable of name strings. Each name is matched literally and
            case-insensitively; empty names are ignored.
        text: the text in which the names are replaced.

    Returns:
        A new string with every match replaced by the placeholder 'Person'.
    """
    new_text = text
    for person in people:
        # An empty pattern matches at every position and would insert the
        # placeholder between all characters, so blank entries are skipped.
        if not person:
            continue
        # Escape the name so characters such as '.' or '(' are matched
        # literally instead of being interpreted as regex syntax.
        new_text = re.sub(re.escape(person), 'Person', new_text, flags=re.IGNORECASE)
    return new_text

In [4]:
# Fetch the small German spaCy pipeline that a later cell loads with spacy.load("de_core_news_sm").
!python -m spacy download de_core_news_sm

Collecting de-core-news-sm==3.5.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.5.0/de_core_news_sm-3.5.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m61.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[0m[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('de_core_news_sm')


In [5]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Corpus metadata, restricted to rows whose language column is exactly "['de']".
data_dir = Path.cwd() / "data"
metadata = pl.read_csv(data_dir / "metadata.csv").filter(pl.col("language") == "['de']")

# Class index for each author; `authors` lists the same names in the same order.
author_mapping = {
    "Goethe, Johann Wolfgang von": 0,
    "Schiller, Friedrich": 1,
}
authors = list(author_mapping)

# Tokenizer and two-class sequence classifier built from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bert-base-german-cased")
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-german-cased",
    num_labels=2,
)

# German spaCy pipeline (downloaded in the previous cell).
nlp = spacy.load("de_core_news_sm")


Some weights of the model checkpoint at bert-base-german-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoi

In [6]:
# Read the list of names to anonymise, one name per line.
with open("people.txt", "r") as f:
    # strip() removes the line terminator without chopping the last character
    # of a final line that has no trailing newline; blank lines are dropped so
    # that no empty name reaches the substitution step.
    people = [name for name in (line.strip() for line in f) if name]


In [7]:
# Train/test/validation split, preprocessing, tokenization and chunking.
#
# For every author, all dramas except the last two go to the training split,
# the second-to-last drama to the test split and the last one to the validation
# split. Each document is tokenized as a whole and cut into fixed-size chunks.
# Chunks and labels are appended to their split in the same order, so that
# position i of a label list always belongs to position i of the chunk lists.

encoding_keys = ('input_ids', 'token_type_ids', 'attention_mask')

# Each split holds one list of chunk tensors per tokenizer field plus a label list.
train_encodings = {key: [] for key in encoding_keys}
train_labels = []

test_encodings = {key: [] for key in encoding_keys}
test_labels = []

val_encodings = {key: [] for key in encoding_keys}
val_labels = []


train_ids = []
test_ids = []
val_ids = []

removed_chars = "\r\n\t.:~()[]{}"

# Number of tokens per chunk.
block_size = 32


for author in authors:

    # Select the dramas of this author from the whole Gutenberg corpus.
    author_ids = metadata.filter((pl.col("author") == author) & (pl.col("subjects").str.contains("Drama")))["id"].to_list()

    # Split by whole document so that the model never sees any part of a
    # test/validation play during training; otherwise it would probably just
    # learn to recognize the names of the characters.
    # Note: author_ids[-2] raises IndexError when an author has fewer than two
    # documents, and the training split is empty with exactly two.
    #
    # TODO: implement name removal
    # https://stackoverflow.com/questions/53534376/removing-names-from-noun-chunks-in-spacy

    train_ids += author_ids[:-2]
    test_ids += [author_ids[-2]]
    val_ids += [author_ids[-1]]

    for doc_id in author_ids:

        file_path = data_dir / "raw" / (doc_id + "_raw.txt")

        try:
            with open(file_path, "r") as in_f:
                raw_text = in_f.read()
        except FileNotFoundError:
            print(f"Warning file not found{file_path}")
            continue

        # Script for removing the Project Gutenberg headers.
        text = strip_headers(raw_text)

        # Names from the people file are only replaced in training documents.
        if doc_id in train_ids:
            text = remove_people(people, text)

        # More custom header and footer stripping.
        # PG47804 has a longer appendix: cut the text where the first
        # "Fußnoten" heading starts. The match offset is counted from the
        # beginning of the text, so it is used directly as the end index.
        footnotes = re.search(r'\b(Fußnoten)\b', text) if doc_id == 'PG47804' else None
        if footnotes is not None:
            text = text[1000:footnotes.start()]
        else:
            # For the rest (or when no heading is found) strip another 1000
            # characters from both ends.
            text = text[1000:-1000]

        # str.strip returns a new string, so the result has to be assigned.
        text = text.strip(removed_chars)

        # Encode the full document first, without truncation or padding, and
        # cut the token sequence into chunks afterwards.
        encoding = tokenizer(text, return_tensors="pt")

        # The input ids here are token ids, not document ids.
        # The final chunk is dropped, so no chunk is shorter than block_size.
        encodings = {key: list(torch.split(encoding[key], block_size, dim=1)[:-1])
                     for key in encoding_keys}
        labels = [author_mapping[author]] * len(encodings['input_ids'])

        if doc_id in train_ids:
            target_encodings, target_labels = train_encodings, train_labels
        elif doc_id in test_ids:
            target_encodings, target_labels = test_encodings, test_labels
        elif doc_id in val_ids:
            target_encodings, target_labels = val_encodings, val_labels
        else:
            # A document that belongs to no split is ignored.
            continue

        # Append chunks and labels in the same order to keep them aligned.
        for key in encoding_keys:
            target_encodings[key].extend(encodings[key])
        target_labels.extend(labels)


Token indices sequence length is longer than the specified maximum sequence length for this model (4750 > 512). Running this sequence through the model will result in indexing errors


In [8]:
# train_encodings["input_ids"]      = torch.split(torch.cat(train_encodings['input_ids'],dim=0),16, dim=0)
# train_encodings["token_type_ids"] = torch.split(torch.cat(train_encodings['token_type_ids'],dim=0),16, dim=0)
# train_encodings["attention_mask"] = torch.split(torch.cat(train_encodings['attention_mask'],dim=0),16, dim=0)
# train_labels = torch.split(torch.cat(train_labels, dim=0), 16, dim=0)


# test_encodings["input_ids"]      = torch.split(torch.cat(test_encodings['input_ids'],dim=0),16, dim=0)
# test_encodings["token_type_ids"] = torch.split(torch.cat(test_encodings['token_type_ids'],dim=0),16, dim=0)
# test_encodings["attention_mask"] = torch.split(torch.cat(test_encodings['attention_mask'],dim=0),16, dim=0)
# test_labels = torch.split(torch.cat(test_labels, dim=0), 16, dim=0)




# val_encodings["input_ids"]      = torch.split(torch.cat(val_encodings['input_ids'],dim=0),16, dim=0)
# val_encodings["token_type_ids"] = torch.split(torch.cat(val_encodings['token_type_ids'],dim=0),16, dim=0)
# val_encodings["attention_mask"] = torch.split(torch.cat(val_encodings['attention_mask'],dim=0),16, dim=0)
# val_labels = torch.split(torch.cat(val_labels, dim=0), 16, dim=0)


In [9]:
# Drop the size-1 dimensions of every stored chunk, in place, for all three splits.
for split in (train_encodings, test_encodings, val_encodings):
    for key, chunks in split.items():
        split[key] = list(map(torch.squeeze, chunks))

In [10]:
# Wrap the chunk lists and labels of each split in a dataset object.
train_dataset, test_dataset, val_dataset = (
    GutenbergDataset(split_encodings, split_labels)
    for split_encodings, split_labels in (
        (train_encodings, train_labels),
        (test_encodings, test_labels),
        (val_encodings, val_labels),
    )
)



In [11]:

# Training batches are shuffled; the evaluation loader iterates the validation dataset.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=8)
eval_loader = DataLoader(dataset=val_dataset, batch_size=8)




In [12]:
# Separate accuracy trackers for the training and validation passes.
train_metric = evaluate.load("accuracy")
val_metric = evaluate.load("accuracy")

num_epochs = 7
# One optimizer/scheduler step is taken per training batch.
num_training_steps = num_epochs * len(train_loader)

# The progress bar is created in the training cell itself, so none is built
# here; creating one in this cell only left an unused bar at 0%.
optim = AdamW(model.parameters(), lr=4e-5)
lr_scheduler = get_scheduler(name="linear", optimizer=optim, num_warmup_steps=500, num_training_steps=num_training_steps)

  0%|          | 0/17885 [00:00<?, ?it/s]



In [13]:

# Progress bar covering every optimisation step of this run.
progress_bar = tqdm(range(num_training_steps))

model.to(device)
losses = []
for epoch in range(num_epochs):
    # ---- optimisation pass over the training data ----
    model.train()
    for counter, batch in enumerate(train_loader, start=1):
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        # The first element of the model output is used as the loss.
        loss = outputs[0]
        loss.backward()
        optim.step()
        losses.append(loss.item())
        # Every 100 batches, print the mean of all losses recorded so far.
        if counter % 100 == 0:
            print(sum(losses) / len(losses))
        lr_scheduler.step()
        progress_bar.update(1)

    # ---- accuracy on the first 200 batches drawn from the training loader ----
    model.eval()
    for counter, batch in enumerate(train_loader, start=1):
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        logits = outputs.logits
        predictions = logits.argmax(dim=-1)
        train_metric.add_batch(predictions=predictions, references=batch["labels"])
        if counter == 200:
            break
    print(train_metric.compute())

    # ---- accuracy on the complete validation loader ----
    for batch in eval_loader:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        logits = outputs.logits
        predictions = logits.argmax(dim=-1)
        val_metric.add_batch(predictions=predictions, references=batch["labels"])

    print(val_metric.compute())


  0%|          | 0/17885 [00:00<?, ?it/s]

0.6981312268972397
0.6749981753528118
0.6621153472860655
0.6389075046032667
0.6183659367263317
0.6025982542087635
0.590913766260658
0.5822002391517163
0.5705316083878279
0.5595698968544602
0.5506713848493316
0.5436614143103361
0.5405218811218555
0.5367326222147261
0.5321390666812659
0.52604173053056
0.5204451663090902
0.5173937512685856
0.5127881113635866
0.5085218100585044
0.5034001012040036
0.5016161739504473
0.4985649181076366
0.49526323518250137
0.49104332123547795
{'accuracy': 0.871875}
{'accuracy': 0.5029585798816568}
0.48273519953947847
0.4764701568638186
0.4716200244960403
0.4659744054627449
0.4613474572086539
0.4557882199131791
0.45381815048326635
0.4509327900375277
0.4472023490266457
0.4432174834641517
0.4399947926200794
0.43831260121302223
0.4343034257462439
0.43099205215872466
0.4280692298897759
0.42535851676371195
0.42183388470596556
0.4189004816512666
0.415895166051157
0.41262191048200714
0.409852102877038
0.40834976532520995
0.40669959392954896
0.40510565887358485
0.4031

In [14]:
# Show the predicted class indices of the last batch processed by the validation loop in the training cell.
print(predictions)

tensor([0], device='cuda:0')
