<a href="https://colab.research.google.com/github/ipavlopoulos/paremia/blob/main/bert-gr-c.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GrBERT on Greek Proverbs


In [4]:
%%capture
# %%capture silences this cell's output (including the pip log).
!pip install transformers
from transformers import BertModel, BertTokenizer
# Name of the pretrained checkpoint; also passed to BertModel.from_pretrained in BertClassifier.
model_name = 'nlpaueb/bert-base-greek-uncased-v1'
# Tokenizer used as a module-level global by the Dataset class defined further down.
tokenizer = BertTokenizer.from_pretrained(model_name)

In [None]:
# Accumulator for F1 results; the assessment cell further down re-initialises it before filling it.
f1_scores = []

* To use three splits, re-run the notebook once per seed (set `i` to 0, 1 and 2 in the next cell), saving the scores each time.

In [87]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import *
from pathlib import Path

# Locate the corpus: use the local copy when it exists, otherwise read it from GitHub.
local_corpus = "data/balanced_corpus.csv"
remote_corpus = 'https://raw.githubusercontent.com/ipavlopoulos/paremia/main/data/balanced_corpus.csv'
corpus_path = local_corpus if Path(local_corpus).exists() else remote_corpus
balanced_corpus = pd.read_csv(corpus_path, index_col=0)

# Seed offset: change `i` and re-run to obtain a different split.
i = 2
seed = 2023 + i

# Hold out 5% as the test set, then carve a dev set of the same size out of the remainder.
train, test = train_test_split(balanced_corpus, test_size=0.05, random_state=seed)
train, dev = train_test_split(train, test_size=len(test), random_state=seed)

In [88]:
import torch
from sklearn.preprocessing import OneHotEncoder

# Label vocabulary: every distinct training area gets an integer id (in order of first
# appearance), and loc2idx is the inverse lookup.
idx2loc = dict(enumerate(train.area.unique()))
loc2idx = {area: idx for idx, area in idx2loc.items()}

class Dataset(torch.utils.data.Dataset):
    """Tokenised texts paired with one-hot area labels.

    Reads two module-level globals at construction time: ``loc2idx`` (area name ->
    class index) and ``tokenizer`` (called once per text).

    Args:
        df: frame with an ``area`` column (every value must be a key of ``loc2idx``,
            otherwise ``KeyError`` is raised) and a ``text`` column.
        max_length: length every text is padded/truncated to by the tokenizer.

    Attributes:
        labels: float64 array of shape ``(len(df), len(loc2idx))``; column ``k`` is 1
            exactly when the row's area has index ``k`` in ``loc2idx``.
        texts: object array holding one tokenizer output per row.
    """

    def __init__(self, df, max_length = 32):
        self.max_length = max_length
        class_ids = np.asarray(df.area.apply(lambda a: loc2idx[a]).values, dtype=np.int64)
        # Build the one-hot rows against the FULL label vocabulary. Fitting an encoder
        # on this frame's labels alone would drop the columns of classes that happen to
        # be absent from it (e.g. in a small dev/test split) and shift the remaining
        # ones, so argmax would no longer agree with loc2idx / the model's outputs.
        self.labels = np.eye(len(loc2idx), dtype=np.float64)[class_ids]
        self.texts = np.array(df.text.apply(lambda txt: tokenizer(txt, padding='max_length', max_length = self.max_length, truncation=True, return_tensors="pt")).values)

    def __len__(self):
        """Number of examples."""
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Return ``(tokenizer_output, one_hot_label)`` for position ``idx``."""
        batch_texts = self.texts[idx]
        batch_labels = self.labels[idx]
        return batch_texts, batch_labels

In [89]:
from torch import nn

class BertClassifier(nn.Module):
    """Pretrained BERT encoder followed by a small classification head.

    Head: dropout -> linear(hidden, 128) -> ReLU -> BatchNorm1d(128) -> linear(128, num_classes).
    The encoder is loaded from the module-level ``model_name``.

    Args:
        dropout: dropout probability applied to the pooled encoder output.
        num_classes: number of output logits.
    """

    def __init__(self, dropout=0.1, num_classes=1):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        # Take the input width of the head from the encoder instead of hard-coding 768,
        # so a checkpoint with another hidden size works too. For the base model this is
        # still 768, hence previously saved state dicts remain loadable.
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(hidden_size, 128, bias=True)
        self.norm = nn.BatchNorm1d(128)
        self.linear2 = nn.Linear(128, num_classes, bias=True)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        """Return the unnormalised class logits for a batch.

        Args:
            input_id: token ids, passed to the encoder as ``input_ids``.
            mask: attention mask, passed to the encoder as ``attention_mask``.
        """
        # With return_dict=False the encoder returns a tuple; the second element is the
        # pooled output that feeds the head.
        _, pooled_output = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        hidden = self.dropout(pooled_output)
        hidden = self.relu(self.linear1(hidden))
        hidden = self.norm(hidden)
        return self.linear2(hidden)

In [90]:
from torch.optim import Adam
from tqdm import tqdm

def validate(model, dataloader, device="cpu", criterion=nn.CrossEntropyLoss()):
    """Run the model over a dataloader without gradient tracking.

    Args:
        model: callable as ``model(input_ids, attention_mask)`` returning per-class scores;
            it is switched to eval mode here and is expected to already live on ``device``.
        dataloader: iterable of ``(inputs, labels)`` batches, where ``inputs`` maps
            ``'input_ids'`` / ``'attention_mask'`` to tensors and ``labels`` holds one
            row of per-class targets per example.
        device: device the batch tensors are moved to.
        criterion: loss applied to ``(output, labels)`` for every batch.

    Returns:
        ``(predictions, gold_labels, mean_loss)``: the argmax class index of every
        example's output and label, plus the loss averaged over the batches
        (``0.0`` when the dataloader yields no batch).
    """
    predictions, gold_labels = [], []
    model.eval()
    val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for val_input, val_label in dataloader:
            val_label = val_label.to(device)
            mask = val_input['attention_mask'].to(device)
            # drop the singleton dimension 1 of the stacked input ids
            input_id = val_input['input_ids'].squeeze(1).to(device)
            output = model(input_id, mask)
            val_loss += criterion(output, val_label).item()
            gold_labels.extend(np.argmax(val_label.cpu().numpy(), axis=1))
            predictions.extend(np.argmax(output.cpu().numpy(), axis=1))
            num_batches += 1
    # Average over the number of batches. Dividing by the last batch index instead
    # under-counts by one and divides by zero when there is exactly one batch.
    mean_loss = val_loss / num_batches if num_batches else 0.0
    return predictions, gold_labels, mean_loss

def finetune(model, train_data, val_data, learning_rate=2e-5, epochs=10, criterion=nn.CrossEntropyLoss(),
             batch_size=32, max_length=32, patience=2):
    """Fine-tune ``model`` with Adam and early stopping on the validation macro-F1.

    After every epoch the model is scored on ``val_data``; whenever the macro-F1
    improves, the weights are written to ``checkpoint.pt``. Training stops once the
    score has failed to improve for ``patience`` consecutive epochs, or after
    ``epochs`` epochs. In both cases the best checkpoint written during THIS call is
    loaded back before returning; if no epoch ever improved on an F1 of 0, no
    checkpoint is loaded and the current weights are kept.

    Args:
        model: module called as ``model(input_ids, attention_mask)``.
        train_data: frame wrapped in ``Dataset`` for training (shuffled every epoch).
        val_data: frame wrapped in ``Dataset`` for validation.
        learning_rate: Adam learning rate.
        epochs: maximum number of epochs.
        criterion: loss applied to ``(output, labels)``.
        batch_size: batch size of both dataloaders.
        max_length: token length passed to ``Dataset``.
        patience: number of consecutive non-improving epochs tolerated.

    Returns:
        ``(model, train_losses, val_losses)`` with the model in eval mode and one
        mean loss per completed epoch in each list.
    """
    train_losses = []
    val_losses = []

    train_dataloader = torch.utils.data.DataLoader(Dataset(train_data, max_length=max_length),
                                                   batch_size=batch_size, shuffle=True, drop_last=False)
    val_dataloader = torch.utils.data.DataLoader(Dataset(val_data, max_length=max_length),
                                                 batch_size=batch_size, drop_last=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = Adam(model.parameters(), lr=learning_rate)
    model.to(device)
    best_f1 = 0
    best_epoch = 0
    epochs_not_improving = 0
    # Only restore a checkpoint that this call wrote; a file left over from an
    # earlier run may belong to a different model.
    checkpoint_saved = False
    for epoch_num in range(epochs):
            # validate() leaves the model in eval mode, so re-enable training each epoch
            model.train()
            total_loss_train = 0.0
            num_batches = 0
            for inputs, labels in tqdm(train_dataloader):
                output = model(inputs['input_ids'].squeeze(1).to(device),
                               inputs['attention_mask'].to(device))
                batch_loss = criterion(output.to(device), labels.to(device))
                total_loss_train += batch_loss.item()
                num_batches += 1

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
            # mean over the number of batches (guarded for an empty training set)
            train_loss = total_loss_train / max(num_batches, 1)
            train_losses.append(train_loss)

            predictions, gold_labels, val_loss = validate(model, val_dataloader, device, criterion)
            f1 = f1_score(gold_labels, predictions, average='macro', zero_division=0)
            val_losses.append(val_loss)
            if f1 > best_f1:
                print(f"New best epoch found: {epoch_num} (f1: {f1:.3f})!")
                best_f1 = f1
                best_epoch = epoch_num
                torch.save(model.state_dict(), "checkpoint.pt")
                checkpoint_saved = True
                epochs_not_improving = 0
            else:
                epochs_not_improving += 1
                if epochs_not_improving >= patience:
                    print('Patience is up, restoring the best model and exiting...')
                    break
            print(
                f'Epochs: {epoch_num + 1} | Train Loss: {train_loss: .3f} \
                | Val Loss: {val_loss: .3f} (best epoch: {best_epoch} w/f1: {best_f1:.3f})')
    # Hand back the best weights whether training stopped early or used all epochs.
    if checkpoint_saved:
        model.load_state_dict(torch.load("checkpoint.pt"))
    model.eval()
    return model, train_losses, val_losses

In [None]:
# Fine-tune a fresh classifier with one output per training area, validating on the dev split.
finetune_options = dict(epochs=100, patience=5, batch_size=64, max_length=32)
model, train_losses, val_losses = finetune(
    BertClassifier(num_classes=len(loc2idx)), train, dev, **finetune_options
)

In [None]:
# Per-epoch loss curves drawn on an explicit Axes object.
fig, ax = plt.subplots()
ax.plot(train_losses, label='Training Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_title('Training and Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()

In [94]:
# Store the weights under the exact name the assessment loop loads
# ("bert-gr-c-seed-<i>.pt"); a double dash after "c" would never be found there.
torch.save(model.state_dict(), f"bert-gr-c-seed-{i}.pt")

* Upload to the cloud

In [None]:
!gsutil cp bert-gr-c-seed0.pt gs://{bucket_name}/ # up
!gsutil cp bert-gr-c-seed1.pt gs://{bucket_name}/ # up
!gsutil cp bert-gr-c-seed2.pt gs://{bucket_name}/ # up

# Assessing three models
* Three splits

In [None]:
# to use these, initialise the bucket_name and project_id
bucket_name, project_id = None, None # <== set these properly
!gcloud config set project {project_id}
!gsutil cp f'gs://{bucket_name}/proverbs/bert-gr-c-seed0.pt' ./
!gsutil cp f'gs://{bucket_name}/proverbs/bert-gr-c-seed1.pt' ./
!gsutil cp f'gs://{bucket_name}/proverbs/bert-gr-c-seed2.pt' ./

In [111]:
from sklearn.metrics import f1_score
# Score the three saved models, each on the test split produced by its own seed.
f1_scores = []
for i in range(3):
  seed = 2023+i
  # Re-create the split used when model i was trained.
  train, test = train_test_split(balanced_corpus, test_size=0.05, random_state=seed)
  train, dev = train_test_split(train, test_size=test.shape[0], random_state=seed)
  # the areas that will serve as target label indices
  # (distinct variable names so the seed index i is not shadowed inside the comprehensions)
  idx2loc = dict(enumerate(train.area.unique()))
  loc2idx = {area: idx for idx, area in idx2loc.items()}
  labels = test.area.unique()
  test_dataloader = torch.utils.data.DataLoader(Dataset(test), batch_size=1, drop_last=False)
  model = BertClassifier(num_classes=len(loc2idx))
  # Evaluation runs on the CPU, so map the stored tensors there; without map_location a
  # checkpoint saved from a GPU cannot be loaded on a machine without CUDA.
  model.load_state_dict(torch.load(f"bert-gr-c-seed-{i}.pt", map_location="cpu"))
  p,l,_ = validate(model.to("cpu"), test_dataloader, "cpu")
  gold_areas = [idx2loc[k] for k in l]
  predicted_areas = [idx2loc[k] for k in p]
  # one F1 per area, keyed by area name
  f1_scores.append(dict(zip(labels, f1_score(gold_areas, predicted_areas, average=None, labels=labels))))

In [112]:
# One column per seed, one row per area; then the mean and standard error across seeds.
per_seed_f1 = {
    seed_idx: [f1_scores[seed_idx][label] for label in labels]
    for seed_idx in range(3)
}
results = pd.DataFrame(per_seed_f1, index=labels)
results.agg(['mean', 'sem'], axis=1)

Unnamed: 0,mean,sem
Αχαΐα,0.457997,0.021902
Μακεδονία,0.160053,0.06334
Ιωάννινα,0.279432,0.042224
Νάξος,0.324187,0.083406
Κεφαλληνία,0.250654,0.014426
Κύπρος,0.802132,0.021637
Ανατολική Θράκη,0.274183,0.019503
Εύβοια,0.210071,0.033197
Ήπειρος,0.112052,0.034581
Σκύρος,0.538715,0.049389


In [114]:
# Average the per-area means (each taken across seeds) into a single score.
overall_f1 = results.mean(axis=1).mean()
print(f'Overall F1: {overall_f1:.2f}')

Overall F1: 0.30


In [129]:
# Greek area name -> short Latin-script label.
loc_name = {
    'Ρούμελη': 'Roumeli',
    'Κοζάνη': 'Kozani',
    'Κως': 'Kos',
    'Αδριανούπολη': 'Adrian.',
    'Νάουσα': 'Naousa',
    'Σέρρες': 'Serres',
    'Σίφνος': 'Sifnos',
    'Ήπειρος': 'Epirus',
    'Αιτωλία': 'Etolia',
    'Αμοργός': 'Amorgos',
    'Ανατολική Θράκη': 'East Thrace',
    'Αρκαδία': 'Arcadia',
    'Αχαΐα': 'Achaia',
    'Επτάνησος': 'Eptanisos',
    'Εύβοια': 'Eyvoia',
    'Θεσπρωτία': 'Thesprotia',
    'Θράκη': 'Thrace',
    'Ιωάννινα': 'Ioannina',
    'Κάρπαθος': 'Karpathos',
    'Κεφαλληνία': 'Kefalinia',
    'Κρήτη': 'Crete',
    'Κύπρος': 'Cyprus',
    'Λέσβος': 'Lesvos',
    'Λακωνία': 'Laconia',
    'Μακεδονία': 'Maced.',
    'Μικρά Ασία': 'Asia Minor',
    'Νάξος': 'Naxos',
    'Πόντος': 'Pontos',
    'Ρόδος': 'Rodos',
    'Σκύρος': 'Skyros',
}
# Mean F1 per area across seeds, keyed by the Latin-script label.
latin_index = results.index.map(lambda area: loc_name[area])
results.set_index(latin_index).mean(axis=1).to_dict()

{'Achaia': 0.4579969007421907,
 'Maced.': 0.16005291005291003,
 'Ioannina': 0.2794318792075741,
 'Naxos': 0.3241873430552676,
 'Kefalinia': 0.2506544138123085,
 'Cyprus': 0.8021322378716746,
 'East Thrace': 0.27418268006503305,
 'Eyvoia': 0.21007081038552325,
 'Epirus': 0.11205216105397593,
 'Skyros': 0.5387147335423198,
 'Amorgos': 0.3191919191919192,
 'Laconia': 0.13324236517218976,
 'Asia Minor': 0.11595103991417355,
 'Eptanisos': 0.3215130023640662,
 'Arcadia': 0.10771604938271605,
 'Pontos': 0.6824451570214283,
 'Thesprotia': 0.15125205428658514,
 'Rodos': 0.2710079518590157,
 'Crete': 0.14789272030651343,
 'Etolia': 0.39267399267399267,
 'Thrace': 0.12055555555555557,
 'Karpathos': 0.3386713175836373,
 'Lesvos': 0.42782632256316466}

In [130]:
# Standard error of the F1 per area across seeds, keyed by the Latin-script label.
latin_index = results.index.map(lambda area: loc_name[area])
results.set_index(latin_index).sem(axis=1).to_dict()

{'Achaia': 0.021902044304510204,
 'Maced.': 0.06334031689848084,
 'Ioannina': 0.04222383785420334,
 'Naxos': 0.08340611811731657,
 'Kefalinia': 0.014426483330086186,
 'Cyprus': 0.02163703911850188,
 'East Thrace': 0.01950255380182251,
 'Eyvoia': 0.03319747947552477,
 'Epirus': 0.03458134799897901,
 'Skyros': 0.04938896096017988,
 'Amorgos': 0.01414141414141416,
 'Laconia': 0.008168124634515671,
 'Asia Minor': 0.03177360693520247,
 'Eptanisos': 0.04026414913672579,
 'Arcadia': 0.03330045246127889,
 'Pontos': 0.0106410435854305,
 'Thesprotia': 0.01180732905121293,
 'Rodos': 0.045754714928097524,
 'Crete': 0.08988411022747426,
 'Etolia': 0.059750764847190295,
 'Thrace': 0.10044499754026677,
 'Karpathos': 0.027400649172465447,
 'Lesvos': 0.03408459111807734}