In [1]:
import csv
import os
import time
from collections import Counter

import matplotlib.pyplot as plt
import torch
from torch import nn, optim

from load_dataset import load_dataset, add_label_id
from models import T5Classifier, T5EncoderClassifier

In [2]:
# Category name -> integer class id used as the classification target.
label_to_id = {
    "Web": 0,
    "International": 1,
    "Etat": 2,
    "Wirtschaft": 3,
    "Panorama": 4,
    "Sport": 5,
    "Wissenschaft": 6,
    "Kultur": 7,
    "Inland": 8,
}
# Reverse lookup (class id -> category name), derived from the forward map
# so the two tables cannot drift apart.
id_to_label = {class_id: name for name, class_id in label_to_id.items()}

In [5]:
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, else CPU.
# The MPS lookup goes through getattr because older torch builds have no
# torch.backends.mps attribute at all.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cpu


In [6]:
# Classification loss; handed to the model constructor, and the training loop
# later calls it through model.loss_fn.
loss_fn = nn.CrossEntropyLoss()

In [7]:
# Build the classifier. The notebook later reads model.lr when creating the
# optimizer and model.use_gradient_clip inside the training loop
# (presumably both are stored as-is by the constructor — TODO confirm in models.py).
model = T5EncoderClassifier(loss_fn=loss_fn, lr=2e-6, use_gradient_clip=True)

In [8]:
# AdamW over every model parameter, using the learning rate carried by the model.
optimizer = optim.AdamW(params=model.parameters(), lr=model.lr)

In [9]:
# Label texts fed to the tokenizer for the decoder side, listed in class-id order.
# Identical to the keys of label_to_id except that "Inland" is written as "Heimat"
# (presumably chosen to tokenize more conveniently — TODO confirm).
simplified_labels = [
    "Web",
    "International",
    "Etat",
    "Wirtschaft",
    "Panorama",
    "Sport",
    "Wissenschaft",
    "Kultur",
    "Heimat",
]

In [10]:
# Decoder-side inputs shared by every example: an all-ones mask with one entry
# per class, and the first token id of each simplified label.
decoder_attentionmask = [1] * len(label_to_id)
decoder_inputs = [
    model.tokenizer(label_text).input_ids[0]
    for label_text in simplified_labels
]


def tokenize_function(examples):
    """Tokenize ``examples["text"]``, padded/truncated to 512 tokens.

    When ``model.is_transformer`` is true, returns a dict with the encoder
    ids/mask under ``input_ids`` / ``input_attention_mask`` plus the shared
    ``decoder_ids`` / ``decoder_attention_mask``. Otherwise returns the
    tokenizer output unchanged.
    """
    encoded = model.tokenizer(examples["text"], padding="max_length", max_length=512, truncation=True)
    if not model.is_transformer:
        return encoded
    return {
        "input_ids": encoded.input_ids,
        "input_attention_mask": encoded.attention_mask,
        "decoder_ids": decoder_inputs,
        "decoder_attention_mask": decoder_attentionmask,
    }

In [11]:
def add_label_id(example):
    """Return ``{"label_id": <int>}`` for the category name in ``example["label"]``.

    Note: this definition replaces the ``add_label_id`` imported from
    ``load_dataset`` at the top of the notebook. Raises ``KeyError`` for a
    label that is not a key of ``label_to_id``.
    """
    category = example["label"]
    return {"label_id": label_to_id[category]}

In [12]:
# Locations of the 10kGNAD train/test CSV files, relative to the notebook.
train_csv_path = "../German_newspaper_articles/10kGNAD/train.csv"
test_csv_path = "../German_newspaper_articles/10kGNAD/test.csv"
train_ds, test_ds = load_dataset(train_csv_path, test_csv_path)

In [13]:
# Add the tokenizer outputs to both splits (train first, then test).
train_ds, test_ds = (split.map(tokenize_function) for split in (train_ds, test_ds))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Map:   0%|          | 0/9245 [00:00<?, ? examples/s]

Map:   0%|          | 0/1028 [00:00<?, ? examples/s]

In [14]:
# Attach the integer class id ("label_id") to both splits (train first, then test).
train_ds, test_ds = (split.map(add_label_id) for split in (train_ds, test_ds))

Map:   0%|          | 0/9245 [00:00<?, ? examples/s]

Map:   0%|          | 0/1028 [00:00<?, ? examples/s]

In [None]:
# Return training examples as torch tensors on the device selected at the top
# of the notebook and move the model there too. A hard-coded "mps" only works
# on Apple-silicon machines and can disagree with the device used for testing.
train_ds.set_format("torch", device=device)
model.to(device)

In [16]:
def eval(val_data):
    """Print and return the top-1 accuracy (in percent) of the global ``model``.

    Examples are fed through the model one at a time with a batch dimension
    of 1 added in front. The function name shadows Python's built-in ``eval``;
    it is kept because the training loop calls it under this name.

    Args:
        val_data: iterable of examples providing ``label_id``, ``input_ids``
            and the mask / decoder fields as torch tensors.

    Returns:
        Accuracy as a float percentage.

    Raises:
        ZeroDivisionError: if ``val_data`` yields no examples.
    """
    model.eval()
    correct = 0
    total = 0  # not named `all`, which would shadow the built-in
    # Evaluation needs no gradients; without no_grad every forward pass would
    # build an autograd graph and waste memory and time.
    with torch.no_grad():
        for data in val_data:
            if model.is_transformer:
                output = model(torch.unsqueeze(data['input_ids'], 0), torch.unsqueeze(data['input_attention_mask'], 0),
                               torch.unsqueeze(data['decoder_ids'], 0), torch.unsqueeze(data['decoder_attention_mask'], 0))
            else:
                output = model(torch.unsqueeze(data['input_ids'], 0), torch.unsqueeze(data['attention_mask'], 0))

            predicted_id = torch.argmax(output).item()
            if predicted_id == data['label_id'].item():
                correct += 1
            total += 1
    accuracy = (correct/total)*100
    print(f"Eval accuracy: {accuracy:.2f}%")
    return accuracy

In [None]:
batch_size = 8
epochs = 10
# Logging cadence in optimizer steps. The plotting cell places loss points
# every 50 steps and accuracy points every 200 steps, so keep these in sync.
log_every = 50
eval_every = 200
running_loss = 0
# Counts optimizer steps across ALL epochs. Using the per-epoch batch index for
# logging let leftover loss from the end of one epoch leak into the first
# average of the next (more than 50 batches summed, still divided by 50).
global_step = 0
train_eval = train_ds.train_test_split(test_size=0.2, shuffle=True)
loss_ls = []
accuracy_ls = []
start_time = time.perf_counter()
for i in range(epochs):
    print(f"Epoch: {i}")
    train_eval = train_eval.shuffle()
    train_split = train_eval["train"]
    # Only full batches are used; a trailing partial batch is skipped.
    for j in range(len(train_split) // batch_size):
        model.train()  # eval() below switches to eval mode, so re-enable training mode each step
        batch_index = j * batch_size
        batch = train_split[batch_index:batch_index+batch_size]
        optimizer.zero_grad()
        if model.is_transformer:
            output = model(batch['input_ids'], batch['input_attention_mask'], batch['decoder_ids'], batch['decoder_attention_mask'])
        else:
            output = model(batch["input_ids"], batch["attention_mask"])

        loss = model.loss_fn(output, batch["label_id"])
        loss.backward()
        if model.use_gradient_clip:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        running_loss += loss.item()
        global_step += 1
        if global_step % log_every == 0:
            mean_loss = running_loss/log_every
            print(f"loss: {mean_loss}")
            loss_ls.append(mean_loss)
            running_loss = 0
        if global_step % eval_every == 0:
            accuracy_ls.append(eval(train_eval["test"]))

end_time = time.perf_counter()
duration = (end_time - start_time)/60
print(f"Training took {duration:0.4f} minutes")

In [None]:
# Persist the training curves. Each list is written as a single CSV cell
# (its Python repr), giving one header row and one data row.
metrics = {"loss": loss_ls,
           "accuracy": accuracy_ls}
path_to_save_metrics = "./t5_results/transformer/metrics"
# Create the output folder up front so a fresh checkout does not fail here
# after training has already finished.
os.makedirs(path_to_save_metrics, exist_ok=True)
# newline="" is what the csv module expects for files it writes; without it
# extra blank lines can appear on platforms that translate "\n".
with open(path_to_save_metrics + f"/metrics_{epochs}_epochs.csv", "w", encoding="utf-8", newline="") as file:
    file_writer = csv.DictWriter(file, fieldnames=metrics.keys())
    file_writer.writeheader()
    file_writer.writerow(metrics)

In [None]:
# x positions of the logged points: one loss value per 50 steps and one
# accuracy value per 200 steps.
loss_steps = [(i+1)*50 for i in range(len(loss_ls))]
accuracy_steps = [(i+1)*200 for i in range(len(accuracy_ls))]

fig, ax1 = plt.subplots()

# Second y-axis on the same x-axis so loss and accuracy get independent scales.
ax2 = ax1.twinx()
plt.gca().ticklabel_format(axis='both', style='plain', useOffset=False)
ax1.plot(loss_steps, loss_ls, 'r-')
ax2.plot(accuracy_steps, accuracy_ls, 'g-')
ax1.set_xlabel('Steps')
ax1.set_ylabel('Loss', color='r')
ax2.set_ylabel('Accuracy', color='g')

plt.title('T5 base')
txt = f"lr={model.lr}, batch_size={batch_size}, epochs={epochs}, gradient_clip=1.0 \n Training duration: {duration:0.2f} minutes"
# The caption sits slightly below the figure area (y = -0.05 in figure coordinates).
plt.figtext(0.5, -0.05, txt, wrap=True, horizontalalignment='center', fontsize=12)

# bbox_inches="tight" grows the saved area to include the caption; with the
# default bounding box anything outside the figure canvas is cropped from the PNG.
plt.savefig(path_to_save_metrics + f"/graph_{epochs}_epochs.png", bbox_inches="tight")
plt.show()

In [None]:
# Module-level tallies filled by test(): total hits and hits per category.
# Later cells read correct_dict, so both stay global.
correct = 0
correct_dict = {"Web": 0, "International": 0, "Etat": 0, "Wirtschaft": 0, "Panorama": 0, "Sport": 0, "Wissenschaft": 0, "Kultur": 0,
                "Inland": 0}
def test():
    """Evaluate the global ``model`` on ``test_ds`` and return accuracy in percent.

    Side effects: resets and then fills the globals ``correct`` and
    ``correct_dict`` (correct predictions per category), switches ``test_ds``
    to torch format, and prints the elapsed time.

    Raises:
        ZeroDivisionError: if ``test_ds`` is empty.
    """
    global correct
    # Start from zero on every call; otherwise re-running the test cell would
    # add to the previous run's counts and inflate the reported accuracy.
    correct = 0
    for category in correct_dict:
        correct_dict[category] = 0

    model.eval()
    # Put the test tensors on the device the model actually lives on, so inputs
    # and weights cannot end up on different devices. Falls back to the global
    # `device` for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    test_ds.set_format("torch", device=target_device)
    start_time = time.perf_counter()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data in test_ds:
            if model.is_transformer:
                output = model(torch.unsqueeze(data['input_ids'], 0), torch.unsqueeze(data['input_attention_mask'], 0),
                               torch.unsqueeze(data['decoder_ids'], 0), torch.unsqueeze(data['decoder_attention_mask'], 0))
            else:
                output = model(torch.unsqueeze(data['input_ids'], 0), torch.unsqueeze(data['attention_mask'], 0))

            predicted_id = torch.argmax(output).item()
            if predicted_id == data['label_id'].item():
                correct_dict[id_to_label[predicted_id]] += 1
                correct += 1
    end_time = time.perf_counter()
    print(f"Test took {(end_time - start_time)/60:0.4f} minutes")
    return (correct/len(test_ds))*100

In [None]:
# Run the test-set evaluation and display the resulting accuracy (in percent).
accuracy_test = test()
accuracy_test

In [None]:
# Save the per-category count of correct test predictions (header row + one data row).
# Make sure the target folder exists before writing.
os.makedirs(path_to_save_metrics, exist_ok=True)
# newline="" is what the csv module expects for files it writes; without it
# extra blank lines can appear on platforms that translate "\n".
with open(path_to_save_metrics + "/test_evaluation.csv", "w", encoding="utf-8", newline="") as file:
    file_writer = csv.DictWriter(file, fieldnames=correct_dict.keys())
    file_writer.writeheader()
    file_writer.writerow(correct_dict)

In [None]:
# Number of test examples per category; used as the denominator for per-category accuracy.
label_counts = Counter(test_ds["label"])
label_counts

In [None]:
# Per-category accuracy as a fraction: correct predictions / test examples of that category.
labels = list(label_counts.keys())
differences = {
    category: correct_dict[category] / label_counts[category]
    for category in labels
}

In [None]:
# Bar chart of per-category test accuracy, in percent.
difference_labels = list(differences.keys())
difference_values = [fraction*100 for fraction in differences.values()]

xs = range(len(difference_labels))
ys = list(difference_values)

fig, ax = plt.subplots()
ax.bar(difference_labels, ys, width=0.6)
ax.set_title("correct per category")
ax.set_xlabel("category")
ax.set_ylabel("accuracy in %")
# Tilt the category names so they do not overlap.
plt.setp(ax.get_xticklabels(), rotation=30, horizontalalignment='right')
fig.savefig(path_to_save_metrics + "/test_evaluation.png")

plt.show()

In [None]:
path_to_model = "./t5_results/model/model.pt"
# Create the target folder first: torch.save does not create missing
# directories, and failing here would throw away the trained model.
os.makedirs(os.path.dirname(path_to_model), exist_ok=True)
# NOTE(review): this pickles the whole model object, so loading it later needs
# the same class definitions importable; saving model.state_dict() is the more
# portable alternative but would change the file format for existing loaders.
torch.save(model, path_to_model)