In [15]:
from torch import nn, optim
import torch
import time
from load_dataset import load_dataset, add_label_id
import csv
import matplotlib.pyplot as plt
from collections import Counter
from models import BertTextClassifier

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
device = "mps"

Using device: cpu


In [17]:
loss_fn = nn.CrossEntropyLoss()

In [18]:
model = BertTextClassifier(loss_fn=loss_fn, lr=2e-7, use_gradient_clip=True)

In [19]:
path_to_model = "./bert_results/model/model.pt"
model = torch.load(path_to_model, map_location="mps")

In [20]:
optimizer = optim.AdamW(model.parameters(), model.lr)

In [21]:
def tokenize_function(examples):
    return model.tokenizer(examples["text"], padding="max_length", truncation=True)

In [22]:
train_ds, test_ds = load_dataset("../German_newspaper_articles/10kGNAD/train.csv", 
                                 "../German_newspaper_articles/10kGNAD/test.csv")

In [23]:
train_ds = train_ds.map(tokenize_function)
test_ds = test_ds.map(tokenize_function)



Map:   0%|          | 0/9245 [00:00<?, ? examples/s]

In [None]:
train_ds = train_ds.map(add_label_id)
test_ds = test_ds.map(add_label_id)

Map:   0%|          | 0/9245 [00:00<?, ? examples/s]

Map:   0%|          | 0/1028 [00:00<?, ? examples/s]

In [None]:
train_ds.set_format("torch", device="mps")
model.to("mps")

BertTextClassifier(
  (loss_fn): CrossEntropyLoss()
  (bert_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31102, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (L

In [None]:
def eval(val_data):
    model.eval()
    correct = 0
    all = 0
    for data in val_data:
        label_id = torch.unsqueeze(data['label_id'],0)
        output = model(torch.unsqueeze(data['input_ids'], 0), torch.unsqueeze(data['attention_mask'], 0))                          
        output = torch.argmax(output)
        if label_id == output:
            correct +=1
        all += 1
    print(f"Eval accuracy: {(correct/all)*100:.2f}%")
    return (correct/all)*100

In [None]:
batch_size = 8
batch_index = 0
running_loss = 0
epochs = 10
train_eval = train_ds.train_test_split(test_size=0.2, shuffle=True)
loss_ls = []
accuracy_ls = []
start_time = time.perf_counter()
for i in range(epochs):
    print(f"Epoch: {i}")
    train_eval = train_eval.shuffle()
    for j in range(int(len(train_eval["train"])/batch_size)):
        model.train()
        batch = train_eval['train'][batch_index:batch_index+batch_size]
        batch_index += batch_size
        optimizer.zero_grad()

        output = model(batch["input_ids"], batch["attention_mask"])
        loss = model.loss_fn(output, batch["label_id"])
        loss.backward()
        if model.use_gradient_clip:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        running_loss += loss.item()
        if j % 50 == 49:
            print(f"loss: {running_loss/50}")
            loss_ls.append(running_loss/50)
            running_loss = 0
        if j % 200 == 199:
            accuracy_ls.append(eval(train_eval["test"]))

    batch_index = 0
       
end_time = time.perf_counter()
print(f"Training took {(end_time - start_time)/60:0.4f} minutes")

In [None]:
metrics = {"loss": loss_ls,
           "accuracy": accuracy_ls}
path_to_save_metrics = "./bert_results/metrics"
with open(path_to_save_metrics + "/metrics_10_epochs.csv", "w", encoding="utf-8") as file:
    file_writer = csv.DictWriter(file, fieldnames=metrics.keys())
    file_writer.writeheader()
    file_writer.writerow(metrics)

In [None]:
loss_steps = [(i+1)*50 for i in range(len(loss_ls))]
accuracy_steps = [(i+1)*200 for i in range(len(accuracy_ls))]

fig, ax1 = plt.subplots()

ax2 = ax1.twinx()
plt.gca().ticklabel_format(axis='both', style='plain', useOffset=False)
ax1.plot(loss_steps, loss_ls, 'r-')
ax2.plot(accuracy_steps, accuracy_ls, 'g-')
ax1.set_xlabel('Steps')
ax1.set_ylabel('Loss', color='r')
ax2.set_ylabel('Accuracy', color='g')

plt.title('T5 base')
txt = "lr=2e-5, batch_size=8, epochs=10, gradient_clip=1.0 \n Training duration: 293.78 minutes"
plt.figtext(0.5, -0.05, txt, wrap=True, horizontalalignment='center', fontsize=12)

plt.savefig(path_to_save_metrics + "/graph_20_epochs.png")
plt.show()

In [None]:
id_to_label = {0: "Web", 1:"International", 2: "Etat", 3: "Wirtschaft", 4: "Panorama",
               5: "Sport", 6: "Wissenschaft", 7: "Kultur", 8: "Inland"}

In [None]:
correct = 0
correct_dict = {"Web": 0, "International": 0, "Etat": 0, "Wirtschaft": 0, "Panorama": 0, "Sport": 0, "Wissenschaft": 0, "Kultur": 0,
                "Inland": 0}
wrong = []
def test():
    model.eval()
    test_ds.set_format("torch", device=device)
    start_time = time.perf_counter()
    for data in test_ds:
        label_id = torch.unsqueeze(data['label_id'], 0)
        if model.is_transformer:
            output = model(torch.unsqueeze(data['input_ids'], 0), torch.unsqueeze(data['input_attention_mask'], 0),
                        torch.unsqueeze(data['decoder_ids'], 0), torch.unsqueeze(data['decoder_attention_mask'], 0))
        else:
            output = model(torch.unsqueeze(data['input_ids'], 0), torch.unsqueeze(data['attention_mask'], 0))
            
        output = torch.argmax(output)
        if label_id == output:
            global correct_dict
            correct_dict[id_to_label[output.item()]] += 1
            global correct
            correct +=1
        else:
            pred = {"sample": data["text"], "prediction": id_to_label[output.item()], "label": data["label"]}
            wrong.append(pred)
    end_time = time.perf_counter()
    print(f"Test took {(end_time - start_time)/60:0.4f} minutes")
    return (correct/len(test_ds))*100 

In [None]:
accuracy_test = test()
accuracy_test

In [None]:
with open(path_to_save_metrics + "/test_evaluation.csv", "w", encoding="utf-8") as file:
    file_writer = csv.DictWriter(file, fieldnames=correct_dict.keys())
    file_writer.writeheader()
    file_writer.writerow(correct_dict)

In [None]:
label_counts = Counter(test_ds["label"])
label_counts

In [None]:
labels = list(label_counts.keys())
differences = dict()
for label in labels:
    differences[label] = correct_dict[label]/label_counts[label]

In [None]:
difference_values = [value*100 for value in differences.values()]
difference_labels = [value for value in differences.keys()]
fig, ax = plt.subplots()

xs = range(len(difference_labels))
ys = [difference_values[x] for x in xs]

ax.bar(difference_labels, ys, 0.6)
plt.title("correct per category")
plt.xlabel("category")
plt.ylabel("accuracy in %")
plt.setp(ax.get_xticklabels(), rotation=30, horizontalalignment='right')
plt.savefig(path_to_save_metrics + "/test_evaluation.png")

plt.show()

In [None]:
path_to_model = "./bert_results/model/model.pt"
torch.save(model, path_to_model)