In [1]:
#%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [1]:
from collections import Counter
from glob import glob
from itertools import islice

import numpy as np
import pandas as pd
import scipy
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.preprocessing import label_binarize
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from era_data import TabletPeriodDataset
from era_model import SimpleCNN

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

# Hyperparameters

In [3]:
IMG_DIR = 'output/images_preprocessed'
VERSION_NAME = 'period_clf_bs16_lr1e-05_20epochs-vanillaCNN-94936_samples-preprocessed-masked-April16-80-10-10_train-test-val'

BATCH_SIZE = 16

# glob() returns matches in arbitrary filesystem order, so sort the candidates to make the
# choice reproducible. Only *.ckpt files are considered, and a missing checkpoint fails with
# a clear message instead of a bare IndexError. If several checkpoints exist, the
# lexicographically first one is used; set CKPT_FN explicitly if you need a different one.
CKPT_CANDIDATES = sorted(glob(f'lightning_logs/{VERSION_NAME}/checkpoints/*.ckpt'))
if not CKPT_CANDIDATES:
    raise FileNotFoundError(f'No checkpoint found in lightning_logs/{VERSION_NAME}/checkpoints/')
CKPT_FN = CKPT_CANDIDATES[0]
CKPT_FN

'lightning_logs/period_clf_bs16_lr1e-05_20epochs-vanillaCNN-94936_samples-preprocessed-masked-April16-80-10-10_train-test-val/checkpoints/epoch=8-step=42723.ckpt'

# Load data and model

In [4]:
# The classifier has one output per entry in the dataset's period-index mapping.
num_classes = len(TabletPeriodDataset.PERIOD_INDICES.keys())
num_classes

22

In [5]:
def collate_fn(batch):
    """Collate dataset samples into an ``(images, labels)`` tensor pair.

    Each sample is read positionally: ``sample[1]`` is a numpy image array and
    ``sample[2]`` its integer label. Every image gets a new leading axis of size 1
    before the images are stacked along a new batch axis.
    """
    images, targets = [], []
    for sample in batch:
        images.append(torch.from_numpy(sample[1]).unsqueeze(0))
        targets.append(sample[2])
    return torch.stack(images), torch.tensor(targets)

In [6]:
# Read the ID column as text: letting pandas parse it as numbers and only then calling
# .astype(str) would silently drop leading zeros (e.g. '000123' -> '123') and break the
# ID match inside the dataset.
test_ids = pd.read_csv(
    f'output/clf_ids/period-test-{VERSION_NAME}.csv',
    header=None,
    dtype=str,
)[0].astype(str)

In [7]:
# Test split restricted to the saved test IDs, with masking enabled.
ds_test = TabletPeriodDataset(IMG_DIR=IMG_DIR, IDS=test_ids, mask=True)

Filtering 94936 IDS down to provided 9494...


In [8]:
# shuffle=False so batches come out in dataset order on every run.
dl_test = DataLoader(
    ds_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=4,
)

In [9]:
# Load the trained SimpleCNN from the checkpoint selected above, passing num_classes explicitly;
# %time reports how long the load takes.
%time model = SimpleCNN.load_from_checkpoint(CKPT_FN, num_classes=num_classes)

CPU times: user 2.32 s, sys: 4.7 s, total: 7.02 s
Wall time: 29 s


In [10]:
# Module.to() moves parameters in place and returns the same module object.
model = model.to(device)

In [11]:
# Inference mode for layers such as dropout/batch-norm (equivalent to model.eval()).
model.train(False);

# Calculate Predictions

In [12]:
def dl2data(dl, MAX_N=None, device=None, model=None):
    """Run a model over ``dl`` and collect raw logits together with the matching labels.

    Args:
        dl: iterable yielding ``(images, labels)`` tensor batches, e.g. a DataLoader.
        MAX_N: optional cap on the number of batches to process (``None`` = all).
        device: where image batches are moved before the forward pass. ``None`` (the
            default) uses the device of the model's parameters (CPU if it has none),
            so inputs always match the model instead of assuming CUDA.
        model: network to evaluate. ``None`` falls back to the notebook-global ``model``,
            preserving the behaviour of calls that omit it.

    Returns:
        ``(logits, y_true)``: per-batch logits stacked row-wise and per-batch labels
        concatenated, with exactly one logits row per label.

    Raises:
        RuntimeError: if no batch could be processed.

    A batch whose forward pass raises is reported and skipped as a whole, so ``logits``
    and ``y_true`` always stay row-aligned.
    """
    if model is None:
        model = globals()['model']
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    try:
        n_batches = len(dl)
    except TypeError:  # plain iterables/generators have no length
        n_batches = None
    if MAX_N is None:
        total = n_batches
    else:
        # Never advertise more batches than actually exist.
        total = MAX_N if n_batches is None else min(MAX_N, n_batches)

    logits = []
    y_true = []  # period indices, aligned row-for-row with `logits`
    n_failed = 0
    model.eval()

    with torch.no_grad():
        for img, period_index in tqdm(islice(dl, MAX_N), total=total):
            try:
                # Compute everything for the batch before recording anything, so a failure
                # cannot leave labels without their logits.
                batch_logits = model(img.to(device)).cpu().numpy()
                batch_labels = period_index.cpu().numpy()
            except Exception as e:
                n_failed += 1
                print(f"Error processing batch: {e}")
                continue
            logits.append(batch_logits)
            y_true.append(batch_labels)

    if n_failed:
        print(f"Skipped {n_failed} batch(es) because of errors.")
    if not logits:
        raise RuntimeError("dl2data: no batches were processed successfully")

    return np.vstack(logits), np.hstack(y_true)

In [13]:
# Pass the device chosen above explicitly so image batches go wherever the model was placed,
# instead of relying on dl2data's default.
logits, y_true = dl2data(dl_test, device=device)

100%|██████████| 580/580 [00:20<00:00, 28.51it/s]


In [14]:
# Every label must have exactly one row of logits; a mismatch means batches were dropped unevenly.
assert y_true.shape[0] == logits.shape[0], (y_true.shape, logits.shape)
y_true.shape, logits.shape

((9279,), (9279, 22))

In [15]:
# Row-wise softmax turns logits into class probabilities; the hard prediction is the top logit.
y_prob = scipy.special.softmax(logits, axis=-1)
y_pred = np.argmax(logits, axis=-1)

In [16]:
tuple(arr.shape for arr in (y_pred, y_prob))

((9279,), (9279, 22))

In [17]:
# Overall top-1 accuracy on the test split.
np.mean(y_pred == y_true)

0.7141933397995474

In [18]:
# zero_division=0 reports 0.0 for classes that are never predicted (the same value the default
# produces) without emitting an UndefinedMetricWarning for each one.
print(classification_report(y_true, y_pred, zero_division=0))

              precision    recall  f1-score   support

           0       0.33      0.02      0.03        57
           1       0.82      0.86      0.84      2740
           2       0.89      0.96      0.92      2328
           3       0.51      0.74      0.60      1705
           4       0.47      0.29      0.35       632
           5       0.48      0.33      0.39       479
           6       0.62      0.40      0.48       302
           7       0.57      0.33      0.42       173
           8       0.59      0.14      0.22       200
           9       0.73      0.51      0.60       145
          10       0.45      0.18      0.26       150
          11       0.85      0.47      0.60       166
          12       0.71      0.34      0.46        35
          13       0.00      0.00      0.00         1
          14       0.67      0.46      0.54        48
          15       0.42      0.31      0.36        42
          16       0.00      0.00      0.00        15
          17       0.88    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [19]:
# Coarse era groups for period names; ERA_MAP maps each period to its group code:
# EB = EARLY_BRONZE, MLB = MID_LATE_BRONZE, I = IRON.
EARLY_BRONZE = {
    'Uruk IV', 'Uruk III', 'Proto-Elamite', 'Linear Elamite',
    'ED I-II', 'ED IIIa', 'ED IIIb', 'Ebla',
    'Old Akkadian', 'Lagash II', 'Ur III', 'Harappan',
}
MID_LATE_BRONZE = {
    'Early Old Babylonian', 'Old Babylonian', 'Old Assyrian',
    'Middle Babylonian', 'Middle Assyrian', 'Middle Elamite', 'Hittite',
}
IRON = {
    'Neo-Assyrian', 'Neo-Babylonian', 'Neo-Elamite',
    'Achaemenid', 'Hellenistic',
}
ERA_MAP = {
    period: era_code
    for era_code, periods in (('EB', EARLY_BRONZE), ('MLB', MID_LATE_BRONZE), ('I', IRON))
    for period in periods
}

In [20]:
def explain(period):
    """Return ``period`` annotated with its ERA_MAP code, e.g. ``'Ur III (EB)'``; unmapped periods get ``'?'``."""
    era = ERA_MAP[period] if period in ERA_MAP else "?"
    return '{} ({})'.format(period, era)

In [21]:
# Invert the dataset's period -> index mapping, then label index 0 as the catch-all 'other' class.
idx2period = {
    **{index: name for name, index in TabletPeriodDataset.PERIOD_INDICES.items()},
    0: 'other',
}

In [22]:
# Keep only classes with at least MIN_SUPPORT test samples; every rarer class is folded into
# index 0 ('other') in the relabelling cells that follow.
MIN_SUPPORT = 40
COMMON_LABELS = sorted(
    int(label) for label, count in Counter(y_true).items() if count >= MIN_SUPPORT
)
print(f'Common labels: ({len(COMMON_LABELS)})')
[(i, explain(idx2period[i])) for i in COMMON_LABELS]

Common labels: (14)


[(0, 'other (?)'),
 (1, 'Ur III (EB)'),
 (2, 'Neo-Assyrian (I)'),
 (3, 'Old Babylonian (MLB)'),
 (4, 'Middle Babylonian (MLB)'),
 (5, 'Neo-Babylonian (I)'),
 (6, 'Old Akkadian (EB)'),
 (7, 'Achaemenid (I)'),
 (8, 'Early Old Babylonian (MLB)'),
 (9, 'ED IIIb (EB)'),
 (10, 'Middle Assyrian (MLB)'),
 (11, 'Old Assyrian (MLB)'),
 (14, 'Lagash II (EB)'),
 (15, 'Ebla (EB)')]

In [23]:
# Fold every ground-truth label outside COMMON_LABELS into class 0 ('other').
rare_true_mask = ~np.isin(y_true, COMMON_LABELS)
y_true_c = y_true.copy()
y_true_c[rare_true_mask] = 0
print(rare_true_mask.mean(), 'changed to "other"')

0.012070266192477638 changed to "other"


In [24]:
# Apply the same folding to the predictions so both sides use the reduced label set.
rare_pred_mask = ~np.isin(y_pred, COMMON_LABELS)
y_pred_c = y_pred.copy()
y_pred_c[rare_pred_mask] = 0
print(rare_pred_mask.mean(), 'changed to "other"')

0.005388511693070374 changed to "other"


In [25]:
# Sorted union (as plain Python ints) of the labels occurring in either the folded ground truth
# or the folded predictions. classification_report lists its rows in sorted label order, so
# PERIOD_LABELS_C must be built in that same order; list(set(...)) does not guarantee it.
indices_c = np.union1d(y_true_c, y_pred_c).tolist()
print(len(indices_c))
print(indices_c)
PERIOD_LABELS_C = [explain(idx2period[i]) for i in indices_c]
print(PERIOD_LABELS_C)

14
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15]
['other (?)', 'Ur III (EB)', 'Neo-Assyrian (I)', 'Old Babylonian (MLB)', 'Middle Babylonian (MLB)', 'Neo-Babylonian (I)', 'Old Akkadian (EB)', 'Achaemenid (I)', 'Early Old Babylonian (MLB)', 'ED IIIb (EB)', 'Middle Assyrian (MLB)', 'Old Assyrian (MLB)', 'Lagash II (EB)', 'Ebla (EB)']


In [26]:
# Pass the labels explicitly so each report row is guaranteed to line up with its name in
# PERIOD_LABELS_C; zero_division=0 avoids a warning for classes that are never predicted.
print(classification_report(
    y_true_c,
    y_pred_c,
    labels=indices_c,
    target_names=PERIOD_LABELS_C,
    zero_division=0,
))

                            precision    recall  f1-score   support

                 other (?)       0.51      0.16      0.24       169
               Ur III (EB)       0.82      0.86      0.84      2740
          Neo-Assyrian (I)       0.89      0.96      0.92      2328
      Old Babylonian (MLB)       0.51      0.74      0.60      1705
   Middle Babylonian (MLB)       0.47      0.29      0.35       632
        Neo-Babylonian (I)       0.48      0.33      0.39       479
         Old Akkadian (EB)       0.62      0.40      0.48       302
            Achaemenid (I)       0.57      0.33      0.42       173
Early Old Babylonian (MLB)       0.59      0.14      0.22       200
              ED IIIb (EB)       0.73      0.51      0.60       145
     Middle Assyrian (MLB)       0.45      0.18      0.26       150
        Old Assyrian (MLB)       0.85      0.47      0.60       166
            Lagash II (EB)       0.67      0.46      0.54        48
                 Ebla (EB)       0.42      0.31

In [27]:
# One-vs-rest ROC AUC for every class present in y_true, macro-averaged.
# Column c of y_prob is the probability of class index c, so each class is scored against its
# own column. The previous code paired the i-th *present* class with column i, which silently
# mismatches as soon as any class index is absent from the test labels.
present_classes = np.unique(y_true)
Y = label_binarize(y_true, classes=present_classes)  # kept for later cells; not used below
n_classes = len(present_classes)

auc_scores = []
for c in present_classes:
    # Binary target built directly, which also works when only two classes are present
    # (label_binarize returns a single column in that case).
    auc = roc_auc_score(y_true == c, y_prob[:, c])
    auc_scores.append(auc)

macro_ovr_auc = np.mean(auc_scores)
print(f"Macro-OvR-AUC: {macro_ovr_auc}")

Macro-OvR-AUC: 0.8836004991593627
