In [1]:
import torch
import pandas as pd
from tqdm import tqdm
from era_data import TabletPeriodDataset, get_IDS
from VAE_model_tablets_class import VAE
import numpy as np
from PIL import Image
from torchvision import transforms

In [2]:
# --- Data / run identification ---
IMG_DIR = 'output/images_preprocessed'  # un-preprocessed alternative: 'output/images'
DATE = 'Oct2-v3'
SUFFIX = '-resnet50'
RUN_NAME_SUFFIX = '-masked_w_classification_loss'  # '' for the plain run

# --- Training hyper-parameters (in this notebook they only feed VERSION_NAME) ---
BATCH_SIZE = 16
EPOCHS = 30
LR = 5e-5

In [3]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
# Tablet IDs returned by get_IDS for IMG_DIR; the count is embedded in VERSION_NAME.
IDS = get_IDS(IMG_DIR=IMG_DIR)
len(IDS)

97640

In [3]:
# Run name of the VAE training job; used below to locate its checkpoint
# directory under lightning_logs/.
VERSION_NAME = (
    f'period_clf_bs{BATCH_SIZE}_lr{LR}_{EPOCHS}epochs{SUFFIX}'
    f'-{len(IDS)}_samples{RUN_NAME_SUFFIX}_blurvae-conv-{DATE}'
)
VERSION_NAME

NameError: name 'IDS' is not defined

In [6]:
# Name of the run whose saved train/test ID lists are reused here (per its name,
# a period-classifier run), so the encodings follow the same split.
VERNAME = 'period_clf_bs16_lr5e-05_15epochs-resnet50-97640_samples-masked-Sept_19_blur-early_stopping-1'


def _read_split_ids(split):
    """Return the tablet IDs of ``split`` ('train' or 'test') as a string Series."""
    path = f'output/clf_ids/period-{split}-{VERNAME}.csv'
    # Header-less single column; force str so numeric-looking IDs are not parsed as ints.
    return pd.read_csv(path, header=None)[0].astype(str)


train_ids = _read_split_ids('train')
test_ids = _read_split_ids('test')

# Guard against leakage: no tablet may appear in both splits.
assert not set(train_ids) & set(test_ids), 'train/test ID splits overlap'

In [7]:
# Build the masked train and test datasets (train first, then test) from the split IDs.
ds_train, ds_test = (
    TabletPeriodDataset(IDS=split_ids, IMG_DIR=IMG_DIR, mask=True)
    for split_ids in (train_ids, test_ids)
)

Filtering 97640 IDS down to provided 97140...
Filtering 97640 IDS down to provided 500...


In [8]:
# Number of period classes defined on the dataset class; passed to
# VAE.load_from_checkpoint below as num_classes.
num_classes = len(TabletPeriodDataset.PERIOD_INDICES)
num_classes

22

In [9]:
# Per-class weights for the VAE's weighted classification loss (see loss_type below).
# map_location='cpu' lets a tensor saved from a GPU session load on a CPU-only
# machine; the VAE constructor appears to move the weights to `device` itself.
class_weights = torch.load("data/class_weights_period.pt", map_location='cpu')

In [10]:
# Final checkpoint of the VAE training run identified by VERSION_NAME.
# NOTE(review): epoch/step in the filename are hard-coded for this specific run.
checkpoint_path = f'lightning_logs/{VERSION_NAME}/checkpoints/epoch=29-step=407516.ckpt'
vae_model = VAE.load_from_checkpoint(
    checkpoint_path,
    map_location=device,  # lets a GPU-trained checkpoint load on a CPU-only machine
    image_channels=1,
    z_dim=16,
    lr=1e-5,
    use_classification_loss=True,
    num_classes=num_classes,
    loss_type="weighted",
    class_weights=class_weights,
    device=device,
)

  self.class_weights = torch.tensor(class_weights).to(device)


In [11]:
def preprocess_image(img, label, genre, prov, size=(178, 218), resample=None):
    """Resize one sample's image and pass its labels through unchanged.

    Args:
        img: image array accepted by ``PIL.Image.fromarray``.
        label: period label, returned as-is.
        genre: genre label, returned as-is.
        prov: provenience label, returned as-is.
        size: target size in PIL order, i.e. ``(width, height)``; the returned
            array is therefore ``height`` rows by ``width`` columns.
        resample: PIL resampling filter; ``None`` means ``Image.NEAREST``.

    Returns:
        ``(resized_img, label, genre, prov)`` where ``resized_img`` is a NumPy array.
    """
    # Resolve the default lazily so the signature does not touch PIL at definition time.
    if resample is None:
        resample = Image.NEAREST
    resized = Image.fromarray(img).resize(size, resample)
    return np.array(resized), label, genre, prov

def process_batch(batch, vae_model, device):
    """Resize and encode one batch of dataset samples with the VAE.

    Args:
        batch: sequence of ``(img, period, genre, prov)`` tuples, as produced by
            indexing the dataset.
        vae_model: model exposing ``representation(tensor)``.
        device: device holding the model; the stacked image tensor is moved there.

    Returns:
        ``(encodings, periods, genres, provs)``. ``encodings`` is the NumPy output of
        ``vae_model.representation`` (first axis follows batch order); the label
        entries are tuples in batch order.
    """
    # ToTensor holds no state, so build it once instead of once per image.
    to_tensor = transforms.ToTensor()

    imgs, periods, genres, provs = zip(
        *(preprocess_image(img, label, genre, prov) for img, label, genre, prov in batch)
    )
    img_tensors = torch.stack([to_tensor(img) for img in imgs]).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        encodings = vae_model.representation(img_tensors)

    return encodings.cpu().numpy(), periods, genres, provs

In [12]:
def get_encodings_and_labels(dataset, vae_model, device=None, batch_size=32):
    """Encode every sample of ``dataset`` with ``vae_model`` and collect its labels.

    Samples are processed in index order, ``batch_size`` at a time, so row ``k``
    of every returned array corresponds to ``dataset[k]``.

    Args:
        dataset: sized, integer-indexable dataset whose items unpack as
            ``(img, period, genre, prov)``.
        vae_model: model exposing ``representation``; it is moved to ``device``
            and switched to eval mode.
        device: device to encode on. ``None`` (default) selects ``'cuda'`` when
            CUDA is available and ``'cpu'`` otherwise.
        batch_size: number of samples per forward pass.

    Returns:
        Tuple ``(encodings, periods, genres, provs)`` of NumPy arrays.
    """
    if device is None:
        # A hard-coded 'cuda' default would crash on CPU-only machines.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    vae_model.to(device)
    vae_model.eval()

    all_encodings = []
    all_periods = []
    all_genres = []
    all_provs = []

    n_samples = len(dataset)
    # One start index per batch; the final batch may be shorter than batch_size.
    for start_idx in tqdm(range(0, n_samples, batch_size)):
        end_idx = min(start_idx + batch_size, n_samples)
        batch = [dataset[j] for j in range(start_idx, end_idx)]
        encodings, periods, genres, provs = process_batch(batch, vae_model, device)
        all_encodings.extend(encodings)
        all_periods.extend(periods)
        all_genres.extend(genres)
        all_provs.extend(provs)

    return np.array(all_encodings), np.array(all_periods), np.array(all_genres), np.array(all_provs)

In [13]:
# Main execution: encode the training split on the device selected earlier
# (the function's own default would otherwise be used).
encodings, periods, genres, provs = get_encodings_and_labels(ds_train, vae_model, device=device)

100%|██████████| 2972/2972 [12:20<00:00,  4.01it/s]


In [14]:
# Create dataframe: one X<k> column per encoding dimension, then the label columns.
feature_columns = [f"X{k}" for k in range(1, encodings.shape[1] + 1)]
df = pd.DataFrame(encodings, columns=feature_columns).assign(
    Period=periods,
    Genre=genres,
    Provenience=provs,
)

In [16]:
# Period label code -> period name; codes run 0..21 in list order.
PERIOD_INDICES = dict(enumerate([
    'other',                 # 0
    'Ur III',                # 1
    'Neo-Assyrian',          # 2
    'Old Babylonian',        # 3
    'Middle Babylonian',     # 4
    'Neo-Babylonian',        # 5
    'Old Akkadian',          # 6
    'Achaemenid',            # 7
    'Early Old Babylonian',  # 8
    'ED IIIb',               # 9
    'Middle Assyrian',       # 10
    'Old Assyrian',          # 11
    'Uruk III',              # 12
    'Proto-Elamite',         # 13
    'Lagash II',             # 14
    'Ebla',                  # 15
    'ED IIIa',               # 16
    'Hellenistic',           # 17
    'ED I-II',               # 18
    'Middle Elamite',        # 19
    'Middle Hittite',        # 20
    'Uruk IV',               # 21
]))

In [17]:
# Provenience label code -> find-site name; codes run 1..118 in list order
# (comments mark the code range of each group of ten).
PROVENIENCE_INDICES = dict(enumerate([
    # 1-10
    'Nineveh', 'Nippur', 'unknown', 'Umma', 'Puzris-Dagan',
    'Girsu', 'Ur', 'Uruk', 'Kanesh', 'Assur',
    # 11-20
    'Adab', 'Garsana', 'Gasur/Nuzi', 'Susa', 'Sippar-Yahrurum',
    'Larsa', 'Nerebtum', 'mod. Babylonia', 'Parsa', 'Kish',
    # 21-30
    'Kalhu', 'Tuttul', 'Suruppak', 'Babili', 'Ebla',
    'mod. Beydar', 'Akhetaten', 'Esnunna', 'Borsippa', 'Kar-Tukulti-Ninurta',
    # 31-40
    'mod. Jemdet Nasr', 'mod. northern Babylonia', 'Alalakh', 'Hattusa', 'Isin',
    'Elbonia', 'Sibaniba', 'Tutub', 'Pi-Kasi', 'Irisagrig',
    # 41-50
    'Ansan', 'Dilbat', 'Zabalam', 'mod. Mugdan/ Umm al-Jir', 'Marad',
    'Eridu', 'Seleucia', 'mod. Abu Halawa', 'Dur-Untas', 'Nagar',
    # 51-60
    'Lagaba', 'Asnakkum', 'Dur-Kurigalzu', 'mod. Tell Sabaa', 'mod. Abu Jawan',
    'mod. Tell Fakhariyah', 'Dur-Abi-esuh', 'Ugarit', 'mod. Diqdiqqah', 'Tarbisu',
    # 61-70
    'Lagash', 'Kisurra', 'Elammu', 'Du-Enlila', 'Kutha',
    'mod. Umm el-Hafriyat', 'Dur-Sarrukin', 'Bad-Tibira', 'Bit-zerija', 'Kilizu',
    # 71-80
    'mod. Pasargadae', 'Abdju', 'Surmes', 'mod. Qatibat', 'Tigunanum',
    'mod. Tell al-Lahm', 'mod. Mesopotamia', 'Subat-Enlil', 'mod. Konar Sandal', 'Gissi',
    # 81-90
    'Agamatanu', 'Aqa', 'Kapri-sa-naqidati', 'Esura', 'Nahalla',
    'Bit-Sahtu', 'mod. Sepphoris', 'Dusabar', 'mod. Tell Sifr', 'Nasir',
    # 91-100
    'Kumu', 'Kazallu', 'Kapru', 'Hurruba', 'mod. Deh-e-no, Iran',
    "mod. Za'aleh", 'mod. Tepe Farukhabad', 'Hursagkalama', 'Carchemish', 'mod. Ben Shemen, Israel',
    # 101-110
    'Kutalla', 'Der', 'Imgur-Enlil', 'mod. Hillah', 'mod. Uhudu',
    'mod. Mahmudiyah', 'Terqa', 'Arrapha', 'mod. Tell en-Nasbeh', 'mod. Kalah Shergat',
    # 111-118
    'Kar-Nabu', 'Harran', 'mod. Til-Buri', 'Shuruppak', 'mod. Abu Salabikh',
    "Ma'allanate", 'Kar-Mullissu', 'mod. Naqs-i-Rustam',
], start=1))


In [18]:
# Genre label code -> genre name; codes run 1..16 in list order.
GENRE_INDICES = dict(enumerate([
    'Administrative',      # 1
    'Letter',              # 2
    'Legal',               # 3
    'Royal/Monumental',    # 4
    'Literary',            # 5
    'Lexical',             # 6
    'Omen',                # 7
    'uncertain',           # 8
    'School',              # 9
    'Mathematical',        # 10
    'Prayer/Incantation',  # 11
    'Scientific',          # 12
    'Ritual',              # 13
    'fake (modern)',       # 14
    'Astronomical',        # 15
    'Private/Votive',      # 16
], start=1))

In [19]:
period_names = df['Period'].map(PERIOD_INDICES)  # codes missing from the dict become NaN
df['Period_Name'] = period_names

In [20]:
genre_names = df['Genre'].map(GENRE_INDICES)  # codes missing from the dict become NaN
df['Genre_Name'] = genre_names

In [21]:
provenience_names = df['Provenience'].map(PROVENIENCE_INDICES)  # codes missing from the dict become NaN
df['Provenience_Name'] = provenience_names

In [24]:
# Persist the training-split encodings with raw and human-readable labels.
# NOTE(review): the test-split export later in this notebook ends in `_w_provs.csv`
# while this one ends in `_w_prov.csv`; left as-is in case downstream readers
# depend on the exact names.
df.to_csv(f"vae_encodings_and_data/vae_encoding_df_{DATE}_w_class_train_w_prov.csv", index=False)

Repeat the same encoding and labelling steps for the test split:

In [26]:
# Main execution: encode the test split on the device selected earlier
# (the function's own default would otherwise be used).
encodings, periods, genres, provs = get_encodings_and_labels(ds_test, vae_model, device=device)

# Create dataframe: one X<k> column per encoding dimension, then the label columns.
df_test = pd.DataFrame(encodings, columns=[f"X{i}" for i in range(1, encodings.shape[1] + 1)])
df_test['Period'] = periods
df_test['Genre'] = genres
df_test['Provenience'] = provs

100%|██████████| 16/16 [00:03<00:00,  4.30it/s]


In [27]:
# Attach human-readable names for each integer label column (unmapped codes -> NaN).
df_test['Period_Name'] = df_test['Period'].map(PERIOD_INDICES)
df_test['Genre_Name'] = df_test['Genre'].map(GENRE_INDICES)
# Provenience names must come from the Provenience codes, not the Genre codes.
df_test['Provenience_Name'] = df_test['Provenience'].map(PROVENIENCE_INDICES)

In [28]:
# Persist the test-split encodings with raw and human-readable labels.
# NOTE(review): this filename ends in `_w_provs.csv` whereas the train export ends
# in `_w_prov.csv`; left as-is in case downstream readers depend on the exact names.
df_test.to_csv(f"vae_encodings_and_data/vae_encoding_df_{DATE}_w_class_test_w_provs.csv", index=False)