In [29]:
import torch
import pandas as pd
from tqdm import tqdm
from multiprocessing import Pool
from era_data import TabletPeriodDataset, get_IDS
from VAE_model_tablets import VAE
import numpy as np
from PIL import Image
from torchvision import transforms

In [16]:
# IMG_DIR = 'output/images'
RUN_NAME_SUFFIX = '-masked' # ''
IMG_DIR = 'output/images_preprocessed'
LR = 5e-5
EPOCHS = 50
BATCH_SIZE = 16
SUFFIX = '-resnet50'

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [7]:
IDS = get_IDS(IMG_DIR=IMG_DIR)
len(IDS)

97640

In [18]:
VERSION_NAME = f'period_clf_bs{BATCH_SIZE}_lr{LR}_{EPOCHS}epochs{SUFFIX}-{len(IDS)}_samples{RUN_NAME_SUFFIX}_blurvae-conv-Sept20'

VERSION_NAME

'period_clf_bs16_lr5e-05_50epochs-resnet50-97640_samples-masked_blurvae-conv-Sept20'

In [19]:
VERNAME = 'period_clf_bs16_lr5e-05_15epochs-resnet50-97640_samples-masked-Sept_19_blur-early_stopping-1'

train_ids = pd.read_csv(f'output/clf_ids/period-train-{VERNAME}.csv', header=None)[0].astype(str)
test_ids = pd.read_csv(f'output/clf_ids/period-test-{VERNAME}.csv', header=None)[0].astype(str)

In [10]:
ds_train = TabletPeriodDataset(IDS=train_ids, IMG_DIR=IMG_DIR, mask=True)
ds_test = TabletPeriodDataset(IDS=test_ids, IMG_DIR=IMG_DIR, mask=True)

Filtering 97640 IDS down to provided 97140...
Filtering 97640 IDS down to provided 500...


In [20]:
chekpoint_path = f'lightning_logs/{VERSION_NAME}/checkpoints/epoch=39-step=535208.ckpt'

vae_model = VAE.load_from_checkpoint(chekpoint_path,image_channels=1,z_dim=16, lr =1e-5 )

In [33]:
# Function to preprocess an image
def preprocess_image(img, label, genre):
    img = np.array(Image.fromarray(img).resize((178, 218), Image.NEAREST))  # Resize the image
    return img, label, genre

def process_batch(batch, vae_model, device):
    # Preprocess each image in the batch
    preprocessed_batch = [preprocess_image(img, label, genre) for img, label, genre in batch]
    imgs, periods, genres = zip(*preprocessed_batch)

    # Convert list of images to torch tensor and move to the device
    img_tensors = torch.stack([transforms.ToTensor()(img) for img in imgs]).to(device)

    with torch.no_grad():
        encodings = vae_model.representation(img_tensors)
    encodings = encodings.cpu().numpy()

    return encodings, periods, genres

In [None]:
def get_encodings_and_labels(dataset, vae_model, device='cuda', batch_size=32):
    vae_model.to(device)
    vae_model.eval()

    all_encodings = []
    all_periods = []
    all_genres = []

    num_batches = (len(dataset) + batch_size - 1) // batch_size

    for i in tqdm(range(num_batches)):
        start_idx = i * batch_size
        end_idx = min(start_idx + batch_size, len(dataset))
        batch = [dataset[j] for j in range(start_idx, end_idx)]
        encodings, periods, genres = process_batch(batch, vae_model, device)
        all_encodings.extend(encodings)
        all_periods.extend(periods)
        all_genres.extend(genres)

    return np.array(all_encodings), np.array(all_periods), np.array(all_genres)

In [34]:
# Main execution
encodings, periods, genres = get_encodings_and_labels(ds_train, vae_model)

# Create dataframe
df = pd.DataFrame(encodings, columns=[f"X{i}" for i in range(1, encodings.shape[1] + 1)])
df['Period'] = periods
df['Genre'] = genres

100%|██████████| 2972/2972 [14:05<00:00,  3.51it/s]


In [35]:
PERIOD_INDICES = {

    0: 'other',
    1: 'Ur III',
    2: 'Neo-Assyrian',
    3: 'Old Babylonian',
    4: 'Middle Babylonian',
    5: 'Neo-Babylonian',
    6: 'Old Akkadian',
    7: 'Achaemenid',
    8: 'Early Old Babylonian',
    9: 'ED IIIb',
    10: 'Middle Assyrian',
    11: 'Old Assyrian',
    12: 'Uruk III',
    13: 'Proto-Elamite',
    14: 'Lagash II',
    15: 'Ebla',
    16: 'ED IIIa',
    17: 'Hellenistic',
    18: 'ED I-II',
    19: 'Middle Elamite',
    20: 'Middle Hittite',
    21: 'Uruk IV'
}

In [36]:
GENRE_INDICES = {
        
        1: 'Administrative',
        2: 'Letter',
        3: 'Legal',
        4: 'Royal/Monumental',
        5: 'Literary',
        6: 'Lexical',
        7: 'Omen',
        8: 'uncertain',
        9: 'School',
        10: 'Mathematical',
        11: 'Prayer/Incantation',
        12: 'Scientific',
        13: 'Ritual',
        14: 'fake (modern)',
        15: 'Astronomical',
        16: 'Private/Votive',
    }

In [37]:
df['Period_Name'] = df['Period'].map(PERIOD_INDICES)

In [38]:
df['Genre_Name'] = df['Genre'].map(GENRE_INDICES)

In [39]:
df

Unnamed: 0,X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13,X14,X15,X16,Period,Genre,Period_Name,Genre_Name
0,1.405217,2.028828,2.021450,3.322014,-2.429911,-0.113327,-0.513614,1.218468,-1.794519,-0.267109,0.106228,1.258922,1.776628,-0.477941,1.245738,-0.668985,1,1,Ur III,Administrative
1,1.941703,1.681338,2.013874,1.214052,-2.624053,-0.565685,0.587282,1.955094,-2.142180,1.434290,-0.085407,0.772144,1.782284,-1.149729,1.159433,-1.051923,1,1,Ur III,Administrative
2,1.545616,0.739789,0.814151,3.536586,-0.738648,-1.797301,0.435611,0.790132,-0.631153,-1.207778,2.223934,0.260261,2.234797,-0.841837,2.341154,-1.226452,1,1,Ur III,Administrative
3,1.471630,1.856850,0.394832,2.731543,-1.828696,-1.380188,0.304390,0.394333,0.091745,-0.505310,-0.243529,-0.879123,1.616585,-0.052371,1.296605,-1.856441,1,1,Ur III,Administrative
4,0.682864,1.135242,-1.071754,-1.261702,-2.055277,-0.845992,-1.013888,0.358597,-1.522938,0.556894,0.668988,-1.384990,1.640509,-2.107970,0.021280,-3.175014,1,1,Ur III,Administrative
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95083,-1.387077,-0.981190,1.870805,4.099249,0.118536,-0.470650,2.735257,-0.293981,0.305332,6.137814,4.190940,0.631822,3.759254,1.938874,-1.443506,1.933406,1,1,Ur III,Administrative
95084,-0.956240,2.186746,3.336722,4.295513,3.944624,-2.227798,-0.343551,-2.244611,1.658638,4.263387,3.765898,3.242968,1.341011,1.107115,0.112043,2.474764,1,4,Ur III,Royal/Monumental
95085,-0.981044,-2.253840,-1.110260,-0.189166,-1.064016,-1.781169,2.626445,1.059962,2.313858,-0.686680,-3.115945,-0.806476,-0.008555,-3.264682,0.826424,-0.741377,3,9,Old Babylonian,School
95086,0.717042,0.621707,-0.251391,1.129647,-0.143000,-0.929157,0.736152,0.734872,-1.055284,-2.549059,-1.171392,-1.852778,3.704706,-2.079251,-1.749792,-0.040010,3,9,Old Babylonian,School


In [41]:
df.to_csv("vae_encodings_and_data/vae_encoding_df_sept30_train.csv", index=False)

I will now repeat the process for the test set:

In [42]:
# Main execution
encodings, periods, genres = get_encodings_and_labels(ds_test, vae_model)

# Create dataframe
df_test = pd.DataFrame(encodings, columns=[f"X{i}" for i in range(1, encodings.shape[1] + 1)])
df_test['Period'] = periods
df_test['Genre'] = genres

100%|██████████| 16/16 [00:06<00:00,  2.60it/s]


In [43]:
df_test['Period_Name'] = df_test['Period'].map(PERIOD_INDICES)
df_test['Genre_Name'] = df_test['Genre'].map(GENRE_INDICES)

In [44]:
df_test.to_csv("vae_encodings_and_data/vae_encoding_df_sept30_test.csv", index=False)