In [1]:
#%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [4]:
from collections import Counter
from glob import glob
from itertools import islice

import numpy as np
import pandas as pd
import scipy
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from era_data import TabletPeriodDataset
from era_model import SimpleCNN

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# Hyperparameters

In [5]:
IMG_DIR = 'output/images_preprocessed'
VERSION_NAME = 'period_clf_bs16_lr5e-05_8epochs-vanillaCNN-94936_samples-preprocessed-masked-March29_1000test'
BATCH_SIZE=16
CKPT_FN = glob(f'lightning_logs/{VERSION_NAME}/checkpoints/*')[0]
CKPT_FN

'lightning_logs/period_clf_bs16_lr5e-05_8epochs-vanillaCNN-94936_samples-preprocessed-masked-March29_1000test/checkpoints/epoch=7-step=46472.ckpt'

# Load data and model

In [6]:
num_classes = len(TabletPeriodDataset.PERIOD_INDICES) + 2
num_classes

24

In [7]:
def collate_fn(batch):
    data = torch.stack([torch.from_numpy(sample[1]).unsqueeze(0) for sample in batch])
    labels = torch.tensor([sample[2] for sample in batch])

    return data, labels

In [8]:
test_ids = pd.read_csv(f'output/clf_ids/period-test-{VERSION_NAME}.csv', header=None)[0].astype(str)

In [16]:
ds_test = TabletPeriodDataset(IDS=test_ids, mask = True, IMG_DIR=IMG_DIR)

Filtering 94936 IDS down to provided 1000...


In [17]:
dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False, num_workers=4)

In [18]:
%time model = SimpleCNN.load_from_checkpoint(CKPT_FN, num_classes=num_classes)

CPU times: user 2.14 s, sys: 1.33 s, total: 3.47 s
Wall time: 3.08 s


In [19]:
model.to(device);

In [20]:
model.eval();

# Calculate Predictions

In [21]:
def dl2data(dl, MAX_N=None, device='cuda'):
    logits = []
    y_true = []  # This will hold the period indices
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():
        gen = tqdm(islice(dl, MAX_N), total=(MAX_N if MAX_N is not None else len(dl)))
        for img, period_index in gen:
            try:
                y_true.append(period_index.cpu().numpy())  # Append period indices
                logits.append(model(img.to(device)).cpu().numpy())  # Process the image through the model
            except Exception as e:
                print(f"Error processing batch: {e}")
    y_true = np.hstack(y_true)
    logits = np.vstack(logits)

    return logits, y_true

In [22]:
logits, y_true = dl2data(dl_test)

100%|██████████| 62/62 [00:02<00:00, 23.35it/s]


In [23]:
y_true.shape, logits.shape

((983,), (983, 24))

In [24]:
y_pred = logits.argmax(axis=-1)
y_prob = scipy.special.softmax(logits, axis=-1)

In [25]:
y_pred.shape, y_prob.shape

((983,), (983, 24))

In [26]:
(y_pred == y_true).mean()

0.7202441505595117

In [27]:
from sklearn.metrics import classification_report

print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00         3
           1       0.78      0.86      0.82       309
           2       0.92      0.94      0.93       250
           3       0.52      0.72      0.60       179
           4       0.55      0.32      0.40        56
           5       0.49      0.41      0.44        49
           6       0.50      0.32      0.39        22
           7       0.60      0.16      0.25        19
           8       0.60      0.10      0.18        29
           9       0.83      0.83      0.83        12
          10       0.33      0.07      0.11        15
          11       0.85      0.50      0.63        22
          12       0.00      0.00      0.00         1
          13       0.00      0.00      0.00         0
          14       0.62      0.83      0.71         6
          15       0.00      0.00      0.00         5
          16       0.00      0.00      0.00         2
          17       1.00    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [28]:
EARLY_BRONZE = {
    'Old Akkadian', 'Ur III',
    'ED IIIb', 'Uruk III',
    'Proto-Elamite', 'Lagash II',
    'Ebla', 'ED IIIa', 'ED I-II',
    'Uruk IV', 'Linear Elamite',
    'Harappan'
    
}
MID_LATE_BRONZE = {
    'Early Old Babylonian',
    'Old Babylonian', 'Old Assyrian',
    'Middle Babylonian', 'Middle Assyrian',
    'Middle Elamite', 'Middle Hittite'
}
IRON = {
    'Neo-Babylonian', 'Neo-Assyrian',
    'Achaemenid', 'Hellenistic',
    'Neo-Elamite'
}
ERA_MAP = {
    **{K: 'EB' for K in EARLY_BRONZE},
    **{K: 'MLB' for K in MID_LATE_BRONZE},
    **{K: 'I' for K in IRON},
}

In [29]:
def explain(period):
    return f'{period} ({ERA_MAP.get(period, "?")})'

In [30]:
idx2period = {v: k for k, v in TabletPeriodDataset.PERIOD_INDICES.items()}
idx2period[0] = 'other'

In [31]:
# let's just use classes with support >=10, everything else goes to 0: other
COMMON_LABELS = list({k for k, v in Counter(y_true).items() if v >= 10})
print(f'Common labels: ({len(COMMON_LABELS)})')
[(i, explain(idx2period[i])) for i in COMMON_LABELS]

Common labels: (11)


[(1, 'Ur III (EB)'),
 (2, 'Neo-Assyrian (I)'),
 (3, 'Old Babylonian (MLB)'),
 (4, 'Middle Babylonian (MLB)'),
 (5, 'Neo-Babylonian (I)'),
 (6, 'Old Akkadian (EB)'),
 (7, 'Achaemenid (I)'),
 (8, 'Early Old Babylonian (MLB)'),
 (9, 'ED IIIb (EB)'),
 (10, 'Middle Assyrian (MLB)'),
 (11, 'Old Assyrian (MLB)')]

In [32]:
y_true_c = y_true.copy()
y_true_c[~np.isin(y_true, COMMON_LABELS)] = 0
print((~np.isin(y_true, COMMON_LABELS)).mean(), 'changed to "other"')

0.021363173957273652 changed to "other"


In [33]:
y_pred_c = y_pred.copy()
y_pred_c[~np.isin(y_pred, COMMON_LABELS)] = 0
print((~np.isin(y_pred, COMMON_LABELS)).mean(), 'changed to "other"')

0.015259409969481181 changed to "other"


In [34]:
indices_c = list(set(y_true_c) | set(y_pred_c))
print(len(indices_c))
print(indices_c)
PERIOD_LABELS_C = [explain(idx2period[i]) for i in indices_c]
print(PERIOD_LABELS_C)

12
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
['other (?)', 'Ur III (EB)', 'Neo-Assyrian (I)', 'Old Babylonian (MLB)', 'Middle Babylonian (MLB)', 'Neo-Babylonian (I)', 'Old Akkadian (EB)', 'Achaemenid (I)', 'Early Old Babylonian (MLB)', 'ED IIIb (EB)', 'Middle Assyrian (MLB)', 'Old Assyrian (MLB)']


In [35]:
print(classification_report(y_true_c, y_pred_c, target_names=PERIOD_LABELS_C))

                            precision    recall  f1-score   support

                 other (?)       0.40      0.29      0.33        21
               Ur III (EB)       0.78      0.86      0.82       309
          Neo-Assyrian (I)       0.92      0.94      0.93       250
      Old Babylonian (MLB)       0.52      0.72      0.60       179
   Middle Babylonian (MLB)       0.55      0.32      0.40        56
        Neo-Babylonian (I)       0.49      0.41      0.44        49
         Old Akkadian (EB)       0.50      0.32      0.39        22
            Achaemenid (I)       0.60      0.16      0.25        19
Early Old Babylonian (MLB)       0.60      0.10      0.18        29
              ED IIIb (EB)       0.83      0.83      0.83        12
     Middle Assyrian (MLB)       0.33      0.07      0.11        15
        Old Assyrian (MLB)       0.85      0.50      0.63        22

                  accuracy                           0.72       983
                 macro avg       0.61      0.4