In [1]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [1]:
from collections import Counter
from itertools import islice

import pandas as pd
import scipy
import torch
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader
from tqdm import tqdm

from era_data import TabletPeriodDataset
from era_model import SimpleCNN

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# Hyperparameters

In [8]:
IMG_DIR = 'output/images_preprocessed'
VERSION_NAME = 'period_clf_bs16_lr5e-05_5epochs-vanillaCNN-94936_samples-preprocessed-March28_1000test'
CKPT_FN = glob(f'lightning_logs/{VERSION_NAME}/checkpoints/*')[0]
CKPT_FN

# Load data and model

In [None]:
train_ids = pd.read_csv(f'output/clf_ids/period-train-{VERSION_NAME}.csv', header=None)[0].astype(str)
test_ids = pd.read_csv(f'output/clf_ids/period-test-{VERSION_NAME}.csv', header=None)[0].astype(str)

In [None]:
ds_train = TabletPeriodDataset(IDS=train_ids, IMG_DIR=IMG_DIR)
ds_test = TabletPeriodDataset(IDS=test_ids, IMG_DIR=IMG_DIR)

In [None]:
def collate_fn(batch):
    
    data = torch.stack([torch.from_numpy(sample[0]) for sample in batch])
    labels = torch.tensor([sample[1] for sample in batch])

    return data, labels

In [37]:
dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False, num_workers=4)

In [39]:
num_classes = len(TabletPeriodDataset.PERIOD_INDICES) + 2
num_classes

24

In [None]:
%time model = SimpleCNN.load_from_checkpoint(CKPT_FN, num_classes=num_classes)

In [45]:
model.to(device);

In [None]:
model.eval();

# Calculate Predictions

In [46]:
def dl2data(dl, MAX_N=None, device='cuda'):
    logits = []
    y_true = []  # This will hold the period indices
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():
        gen = tqdm(islice(dl, MAX_N), total=(MAX_N if MAX_N is not None else len(dl)))
        for img, period_index in gen:
            try:
                y_true.append(period_index.cpu().numpy())  # Append period indices
                logits.append(model(img.to(device)).cpu().numpy())  # Process the image through the model
            except Exception as e:
                print(f"Error processing batch: {e}")
    y_true = np.hstack(y_true)
    logits = np.vstack(logits)

    return logits, y_true

In [47]:
logits, y_true = dl2data(dl_test)


  0%|          | 0/32 [00:00<?, ?it/s][A
  3%|▎         | 1/32 [00:00<00:04,  6.84it/s][A
 16%|█▌        | 5/32 [00:00<00:02,  9.10it/s][A
 28%|██▊       | 9/32 [00:01<00:02,  8.97it/s][A
 41%|████      | 13/32 [00:01<00:01, 12.36it/s][A
 50%|█████     | 16/32 [00:01<00:01, 14.32it/s][A
 56%|█████▋    | 18/32 [00:01<00:01, 11.82it/s][A
 62%|██████▎   | 20/32 [00:01<00:00, 12.52it/s][A
 75%|███████▌  | 24/32 [00:02<00:00, 11.37it/s][A
100%|██████████| 32/32 [00:02<00:00, 12.87it/s][A


In [48]:
y_true.shape, logits.shape

((500,), (500, 24))

In [49]:
y_pred = logits.argmax(axis=-1)
y_prob = scipy.special.softmax(logits, axis=-1)

In [50]:
y_pred.shape, y_prob.shape

((500,), (500, 24))

In [51]:
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00         3
           1       0.44      0.78      0.56       140
           2       0.68      0.70      0.69       109
           3       0.35      0.40      0.37        91
           4       0.00      0.00      0.00        35
           5       0.25      0.04      0.07        23
           6       0.50      0.05      0.10        19
           7       0.00      0.00      0.00        10
           8       1.00      0.11      0.20         9
           9       0.00      0.00      0.00        15
          10       0.00      0.00      0.00         8
          11       0.00      0.00      0.00        12
          12       0.00      0.00      0.00         7
          13       0.50      0.17      0.25         6
          14       1.00      0.20      0.33         5
          15       0.00      0.00      0.00         3
          16       0.00      0.00      0.00         2
          17       0.00    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# Check Metrics

In [52]:
EARLY_BRONZE = {
    'Old Akkadian', 'Ur III',
    'ED IIIb', 'Uruk III',
    'Proto-Elamite', 'Lagash II',
    'Ebla', 'ED IIIa', 'ED I-II',
    'Uruk IV', 'Linear Elamite',
    'Harappan'
    
}
MID_LATE_BRONZE = {
    'Early Old Babylonian',
    'Old Babylonian', 'Old Assyrian',
    'Middle Babylonian', 'Middle Assyrian',
    'Middle Elamite', 'Middle Hittite'
}
IRON = {
    'Neo-Babylonian', 'Neo-Assyrian',
    'Achaemenid', 'Hellenistic',
    'Neo-Elamite'
}
ERA_MAP = {
    **{K: 'EB' for K in EARLY_BRONZE},
    **{K: 'MLB' for K in MID_LATE_BRONZE},
    **{K: 'I' for K in IRON},
}

In [53]:
def explain(period):
    return f'{period} ({ERA_MAP.get(period, "?")})'

In [54]:
idx2period = {v: k for k, v in TabletPeriodDataset.PERIOD_INDICES.items()}
idx2period[0] = 'other'

In [55]:
# let's just use classes with support >=10, everything else goes to 0: other
COMMON_LABELS = list({k for k, v in Counter(y_true).items() if v >= 10})
print(f'Common labels: ({len(COMMON_LABELS)})')
[(i, explain(idx2period[i])) for i in COMMON_LABELS]

Common labels: (9)


[(1, 'Ur III (EB)'),
 (2, 'Neo-Assyrian (I)'),
 (3, 'Old Babylonian (MLB)'),
 (4, 'Middle Babylonian (MLB)'),
 (5, 'Neo-Babylonian (I)'),
 (6, 'Old Akkadian (EB)'),
 (7, 'Achaemenid (I)'),
 (9, 'ED IIIb (EB)'),
 (11, 'Old Assyrian (MLB)')]

In [56]:
y_true_c = y_true.copy()
y_true_c[~np.isin(y_true, COMMON_LABELS)] = 0
print((~np.isin(y_true, COMMON_LABELS)).mean(), 'changed to "other"')

0.092 changed to "other"


In [57]:
y_pred_c = y_pred.copy()
y_pred_c[~np.isin(y_pred, COMMON_LABELS)] = 0
print((~np.isin(y_pred, COMMON_LABELS)).mean(), 'changed to "other"')

0.028 changed to "other"


In [58]:
indices_c = list(set(y_true_c) | set(y_pred_c))
print(len(indices_c))
print(indices_c)
PERIOD_LABELS_C = [explain(idx2period[i]) for i in indices_c]
print(PERIOD_LABELS_C)

10
[0, 1, 2, 3, 4, 5, 6, 7, 9, 11]
['other (?)', 'Ur III (EB)', 'Neo-Assyrian (I)', 'Old Babylonian (MLB)', 'Middle Babylonian (MLB)', 'Neo-Babylonian (I)', 'Old Akkadian (EB)', 'Achaemenid (I)', 'ED IIIb (EB)', 'Old Assyrian (MLB)']


In [59]:
print(classification_report(y_true_c, y_pred_c, target_names=PERIOD_LABELS_C))

                         precision    recall  f1-score   support

              other (?)       0.29      0.09      0.13        46
            Ur III (EB)       0.44      0.78      0.56       140
       Neo-Assyrian (I)       0.68      0.70      0.69       109
   Old Babylonian (MLB)       0.35      0.40      0.37        91
Middle Babylonian (MLB)       0.00      0.00      0.00        35
     Neo-Babylonian (I)       0.25      0.04      0.07        23
      Old Akkadian (EB)       0.50      0.05      0.10        19
         Achaemenid (I)       0.00      0.00      0.00        10
           ED IIIb (EB)       0.00      0.00      0.00        15
     Old Assyrian (MLB)       0.00      0.00      0.00        12

               accuracy                           0.45       500
              macro avg       0.25      0.21      0.19       500
           weighted avg       0.39      0.45      0.39       500

