In [1]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [50]:
import torch

In [51]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# Hyperparameters

In [52]:
LR = 1e-3
EPOCHS = 3
BATCH_SIZE = 16
SUFFIX = '-vanillaCNN'

# Load data

In [53]:
# IMG_DIR = 'output/images'
RUN_NAME_SUFFIX = '-preprocessed2' # ''
IMG_DIR = 'output/images_preprocessed'

In [54]:
#! du -h {IMG_DIR}

In [55]:
from era_data import TabletPeriodDataset, get_IDS
from collections import Counter
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import pandas as pd

In [56]:
IDS = get_IDS(IMG_DIR=IMG_DIR)
len(IDS)

97640

In [57]:
VERSION_NAME = f'period_clf_bs{BATCH_SIZE}_lr{LR}_{EPOCHS}epochs{SUFFIX}-{len(IDS)}_samples{RUN_NAME_SUFFIX}'
VERSION_NAME

'period_clf_bs16_lr0.001_3epochs-vanillaCNN-97640_samples-preprocessed2'

In [58]:
train_ids, test_ids = train_test_split(IDS, test_size=500, random_state=0)
len(train_ids), len(test_ids)

(97140, 500)

In [59]:
ds_train = TabletPeriodDataset(IDS=train_ids, IMG_DIR=IMG_DIR)
ds_test = TabletPeriodDataset(IDS=test_ids)

Filtering 97640 IDS down to provided 97140...
Filtering 97640 IDS down to provided 500...


In [60]:
import numpy as np
from PIL import Image

def collate_fn(batch):
    unsqueezed_data = []
    labels = []

    for sample in batch:
        
        if isinstance(sample[0], np.ndarray):
            img = Image.fromarray(sample[0])
        else:
            img = sample[0] 

        # Resize the image
        img_resized = img.resize((178, 218), Image.NEAREST)

        # Convert the resized PIL image to a tensor and unsqueeze to add a channel dimension
        img_tensor = torch.unsqueeze(torch.tensor(np.array(img_resized), dtype=torch.float32), 0)

        unsqueezed_data.append(img_tensor)
        labels.append(sample[1])

    # Stack all the image tensors and labels together
    data_tensor = torch.stack(unsqueezed_data, dim=0)
    labels_tensor = torch.tensor(labels, dtype=torch.long)

    return data_tensor, labels_tensor

In [61]:
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE,collate_fn=collate_fn, shuffle=True, num_workers=4)
dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False, num_workers=4)

In [62]:
# save model IDs so we can keep track of what data it was trained on
pd.Series(train_ids).to_csv(f'output/clf_ids/period-train-{VERSION_NAME}.csv', index=False, header=None)
pd.Series(test_ids).to_csv(f'output/clf_ids/period-test-{VERSION_NAME}.csv', index=False, header=None)

In [63]:
num_classes = len(TabletPeriodDataset.PERIOD_INDICES) + 2
num_classes

24

In [64]:
import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

class SimpleCNN(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.conv1 = nn.Conv2d(1, 32, 3, 1) 
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(2383104, 128)  
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss
        
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        return {'test_logits': logits, 'test_y': y}
    
    def test_epoch_end(self, outputs):
        # Concatenate all logits and labels gathered from each test_step
        logits = torch.cat([x['test_logits'] for x in outputs], dim=0)
        labels = torch.cat([x['test_y'] for x in outputs], dim=0)
        self.log('test_logits', logits)
        self.log('test_labels', labels)
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

In [65]:
model = SimpleCNN(num_classes=num_classes)

In [66]:
logger = pl.loggers.TensorBoardLogger(
    save_dir='.',
    name='lightning_logs',
    version=VERSION_NAME
)
lr_monitor = LearningRateMonitor(logging_interval='step')

trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator='gpu',
    devices='auto',
    callbacks=[lr_monitor],
    logger=logger
)


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [67]:
print('Logs to:', VERSION_NAME)

Logs to: period_clf_bs16_lr0.001_3epochs-vanillaCNN-97640_samples-preprocessed2


In [68]:
trainer.fit(model, dl_train, dl_test)

  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | conv1 | Conv2d | 320   
1 | conv2 | Conv2d | 18.5 K
2 | fc1   | Linear | 305 M 
3 | fc2   | Linear | 3.1 K 
---------------------------------
305 M     Trainable params
0         Non-trainable params
305 M     Total params
1,220.237 Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [78]:
model.to(device);

In [81]:
from tqdm import tqdm
from itertools import islice


def dl2data(dl, MAX_N=None, device='cuda'):
    logits = []
    y_true = []  # This will hold the period indices
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():
        gen = tqdm(islice(dl, MAX_N), total=(MAX_N if MAX_N is not None else len(dl)))
        for img, period_index in gen:
            try:
                y_true.append(period_index.cpu().numpy())  # Append period indices
                logits.append(model(img.to(device)).cpu().numpy())  # Process the image through the model
            except Exception as e:
                print(f"Error processing batch: {e}")
    y_true = np.hstack(y_true)
    logits = np.vstack(logits)

    return logits, y_true

In [82]:
logits, y_true = dl2data(dl_test)


  0%|          | 0/32 [00:00<?, ?it/s][A
  3%|▎         | 1/32 [00:00<00:04,  7.48it/s][A
 28%|██▊       | 9/32 [00:00<00:00, 42.19it/s][A
 50%|█████     | 16/32 [00:00<00:00, 48.30it/s][A
100%|██████████| 32/32 [00:00<00:00, 38.35it/s][A


In [83]:
y_true.shape, logits.shape

((500,), (500, 24))

In [85]:
import scipy 

y_pred = logits.argmax(axis=-1)
y_prob = scipy.special.softmax(logits, axis=-1)

In [86]:
y_pred.shape, y_prob.shape

((500,), (500, 24))

In [87]:
(y_pred == y_true).mean()

0.634

In [89]:
from sklearn.metrics import classification_report

print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00         3
           1       0.80      0.81      0.80       140
           2       0.84      0.80      0.82       109
           3       0.51      0.57      0.54        91
           4       0.40      0.29      0.33        35
           5       0.22      0.35      0.27        23
           6       0.69      0.47      0.56        19
           7       0.31      0.40      0.35        10
           8       0.50      0.11      0.18         9
           9       0.80      0.53      0.64        15
          10       0.15      0.38      0.21         8
          11       0.62      0.67      0.64        12
          12       0.75      0.43      0.55         7
          13       0.80      0.67      0.73         6
          14       0.60      0.60      0.60         5
          15       0.25      0.33      0.29         3
          16       1.00      0.50      0.67         2
          17       1.00    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [96]:
EARLY_BRONZE = {
    'Old Akkadian', 'Ur III',
    'ED IIIb', 'Uruk III',
    'Proto-Elamite', 'Lagash II',
    'Ebla', 'ED IIIa', 'ED I-II',
    'Uruk IV', 'Linear Elamite',
    'Harappan'
    
}
MID_LATE_BRONZE = {
    'Early Old Babylonian',
    'Old Babylonian', 'Old Assyrian',
    'Middle Babylonian', 'Middle Assyrian',
    'Middle Elamite', 'Middle Hittite'
}
IRON = {
    'Neo-Babylonian', 'Neo-Assyrian',
    'Achaemenid', 'Hellenistic',
    'Neo-Elamite'
}
ERA_MAP = {
    **{K: 'EB' for K in EARLY_BRONZE},
    **{K: 'MLB' for K in MID_LATE_BRONZE},
    **{K: 'I' for K in IRON},
}

In [97]:
def explain(period):
    return f'{period} ({ERA_MAP.get(period, "?")})'

In [98]:
idx2period = {v: k for k, v in TabletPeriodDataset.PERIOD_INDICES.items()}
idx2period[0] = 'other'

In [99]:
# let's just use classes with support >=10, everything else goes to 0: other
COMMON_LABELS = list({k for k, v in Counter(y_true).items() if v >= 10})
print(f'Common labels: ({len(COMMON_LABELS)})')
[(i, explain(idx2period[i])) for i in COMMON_LABELS]

Common labels: (9)


[(1, 'Ur III (EB)'),
 (2, 'Neo-Assyrian (I)'),
 (3, 'Old Babylonian (MLB)'),
 (4, 'Middle Babylonian (MLB)'),
 (5, 'Neo-Babylonian (I)'),
 (6, 'Old Akkadian (EB)'),
 (7, 'Achaemenid (I)'),
 (9, 'ED IIIb (EB)'),
 (11, 'Old Assyrian (MLB)')]

In [100]:
y_true_c = y_true.copy()
y_true_c[~np.isin(y_true, COMMON_LABELS)] = 0
print((~np.isin(y_true, COMMON_LABELS)).mean(), 'changed to "other"')

0.092 changed to "other"


In [101]:
y_pred_c = y_pred.copy()
y_pred_c[~np.isin(y_pred, COMMON_LABELS)] = 0
print((~np.isin(y_pred, COMMON_LABELS)).mean(), 'changed to "other"')

0.086 changed to "other"


In [102]:
indices_c = list(set(y_true_c) | set(y_pred_c))
print(len(indices_c))
print(indices_c)
PERIOD_LABELS_C = [explain(idx2period[i]) for i in indices_c]
print(PERIOD_LABELS_C)

10
[0, 1, 2, 3, 4, 5, 6, 7, 9, 11]
['other (?)', 'Ur III (EB)', 'Neo-Assyrian (I)', 'Old Babylonian (MLB)', 'Middle Babylonian (MLB)', 'Neo-Babylonian (I)', 'Old Akkadian (EB)', 'Achaemenid (I)', 'ED IIIb (EB)', 'Old Assyrian (MLB)']


In [103]:
print(classification_report(y_true_c, y_pred_c, target_names=PERIOD_LABELS_C))

                         precision    recall  f1-score   support

              other (?)       0.53      0.50      0.52        46
            Ur III (EB)       0.80      0.81      0.80       140
       Neo-Assyrian (I)       0.84      0.80      0.82       109
   Old Babylonian (MLB)       0.51      0.57      0.54        91
Middle Babylonian (MLB)       0.40      0.29      0.33        35
     Neo-Babylonian (I)       0.22      0.35      0.27        23
      Old Akkadian (EB)       0.69      0.47      0.56        19
         Achaemenid (I)       0.31      0.40      0.35        10
           ED IIIb (EB)       0.80      0.53      0.64        15
     Old Assyrian (MLB)       0.62      0.67      0.64        12

               accuracy                           0.64       500
              macro avg       0.57      0.54      0.55       500
           weighted avg       0.66      0.64      0.65       500

