In [1]:
import os
from multiprocessing import Pool

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from era_data import TabletPeriodDataset, get_IDS
from VAE_model_tablets_class import VAE

In [2]:
# --- Run configuration ---------------------------------------------------
# Image directory handed to get_IDS and TabletPeriodDataset.
# (alternative value kept for reference: 'output/images')
IMG_DIR = 'output/images_preprocessed'

# In the visible cells these values are only interpolated into VERSION_NAME,
# which in turn selects the checkpoint directory under lightning_logs/.
BATCH_SIZE = 16
EPOCHS = 15
LR = 5e-5
SUFFIX = '-resnet50'
RUN_NAME_SUFFIX = '-masked_w_classification_loss'  # alternative: ''

# Tag used in VERSION_NAME and in the names of the exported CSV files.
DATE = 'Oct2'

In [3]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
# Ask get_IDS for the IDs associated with IMG_DIR, then display how many were returned.
# len(IDS) is also interpolated into VERSION_NAME, so it affects which checkpoint path is built.
IDS = get_IDS(IMG_DIR=IMG_DIR)
len(IDS)

97640

In [5]:
# Rebuild the run name from the configuration values; it is used to locate the
# checkpoint directory under lightning_logs/. The two f-string segments are
# concatenated by the parser into a single string.
VERSION_NAME = (
    f'period_clf_bs{BATCH_SIZE}_lr{LR}_{EPOCHS}epochs{SUFFIX}'
    f'-{len(IDS)}_samples{RUN_NAME_SUFFIX}_blurvae-conv-{DATE}'
)
VERSION_NAME

'period_clf_bs16_lr5e-05_15epochs-resnet50-97640_samples-masked_w_classification_loss_blurvae-conv-Oct2'

In [6]:
# Name of the run whose saved train/test ID lists are reused here.
# NOTE(review): this differs from VERSION_NAME (the run the checkpoint is loaded
# from) -- confirm the VAE was trained on this same split.
VERNAME = 'period_clf_bs16_lr5e-05_15epochs-resnet50-97640_samples-masked-Sept_19_blur-early_stopping-1'

train_ids_csv = f'output/clf_ids/period-train-{VERNAME}.csv'
test_ids_csv = f'output/clf_ids/period-test-{VERNAME}.csv'

# The files are read without a header row; column 0 holds the IDs, which are
# cast to strings.
train_ids = pd.read_csv(train_ids_csv, header=None)[0].astype(str)
test_ids = pd.read_csv(test_ids_csv, header=None)[0].astype(str)

In [7]:
# Build the train and test datasets (in that order) with identical settings;
# only the ID list differs between the two.
ds_train, ds_test = (
    TabletPeriodDataset(IDS=split_ids, IMG_DIR=IMG_DIR, mask=True)
    for split_ids in (train_ids, test_ids)
)

Filtering 97640 IDS down to provided 97140...
Filtering 97640 IDS down to provided 500...


In [12]:
# Number of period classes, taken from the size of the dataset class's
# PERIOD_INDICES table; it is passed to the VAE as num_classes when the
# checkpoint is loaded.
num_classes = len(TabletPeriodDataset.PERIOD_INDICES)
num_classes

22

In [13]:
# Load the stored per-class weights for the period labels.
# map_location=device lets the file load on a CPU-only machine even if the
# tensor was saved from a GPU, matching the CPU fallback chosen for `device`.
class_weights = torch.load("data/class_weights_period.pt", map_location=device)

In [29]:
# Restore the trained VAE from the checkpoint of the run named by VERSION_NAME.
# NOTE(review): assumes VAE is a Lightning module (load_from_checkpoint and the
# lightning_logs/ layout suggest so) -- confirm.
chekpoint_path = f'lightning_logs/{VERSION_NAME}/checkpoints/epoch=14-step=203756.ckpt'
vae_model = VAE.load_from_checkpoint(
    chekpoint_path,
    # Remap stored tensors onto the selected device so a checkpoint written on
    # a GPU can still be opened when only a CPU is available.
    map_location=device,
    # The remaining keyword arguments are the model settings supplied by this notebook.
    image_channels=1,
    z_dim=16,
    lr=1e-5,
    use_classification_loss=True,
    num_classes=num_classes,
    loss_type="weighted",
    class_weights=class_weights,
    device=device,
)

  self.class_weights = torch.tensor(class_weights).to(device)


In [15]:
def preprocess_image(img, label, genre, size=(178, 218)):
    """Resize one image with nearest-neighbour sampling and pass its labels through.

    Args:
        img: Array accepted by ``PIL.Image.fromarray``.
        label: Period label, returned unchanged.
        genre: Genre label, returned unchanged.
        size: Target size handed to ``Image.resize``. PIL interprets it as
            ``(width, height)``, so with the default a 2-D input comes back
            as an array of shape ``(218, 178)``. The default reproduces the
            previously hard-coded value.

    Returns:
        Tuple ``(resized_array, label, genre)`` where ``resized_array`` is a
        numpy array.
    """
    resized = Image.fromarray(img).resize(size, Image.NEAREST)
    return np.array(resized), label, genre

def process_batch(batch, vae_model, device):
    """Encode one batch of dataset samples with the VAE.

    Args:
        batch: Iterable of ``(img, label, genre)`` triples.
        vae_model: Model exposing ``representation(tensor)``; it is called
            under ``torch.no_grad()``.
        device: Device the stacked image tensor is moved to before encoding.

    Returns:
        Tuple ``(encodings, periods, genres)``: ``encodings`` is a numpy array
        produced from the model output, ``periods`` and ``genres`` are tuples
        in batch order.

    Raises:
        ValueError: If ``batch`` is empty, because unpacking ``zip(*[])``
            yields nothing to assign.
    """
    # Resize every image; labels and genres pass through untouched.
    preprocessed_batch = [preprocess_image(img, label, genre) for img, label, genre in batch]
    imgs, periods, genres = zip(*preprocessed_batch)

    # Build the transform once per batch rather than once per image.
    to_tensor = transforms.ToTensor()
    img_tensors = torch.stack([to_tensor(img) for img in imgs]).to(device)

    # Inference only: no gradients are needed for the encodings.
    with torch.no_grad():
        encodings = vae_model.representation(img_tensors)
    encodings = encodings.cpu().numpy()

    return encodings, periods, genres

In [16]:
def get_encodings_and_labels(dataset, vae_model, device='cuda', batch_size=32):
    """Run the whole dataset through the VAE in fixed-size batches.

    The model is moved to ``device`` and switched to eval mode, then the
    dataset is indexed sequentially and each slice is handed to
    ``process_batch``.

    Args:
        dataset: Indexable object supporting ``len()`` whose items are the
            triples expected by ``process_batch``.
        vae_model: Model to encode with.
        device: Device for the model and the image tensors.
        batch_size: Number of dataset items per batch.

    Returns:
        Three numpy arrays ``(encodings, periods, genres)`` with one entry per
        dataset item, in dataset order.
    """
    vae_model.to(device)
    vae_model.eval()

    n_items = len(dataset)
    # Ceiling division: the last batch may be shorter than batch_size.
    n_batches = -(-n_items // batch_size)

    encoding_rows, period_labels, genre_labels = [], [], []

    for batch_no in tqdm(range(n_batches)):
        lo = batch_no * batch_size
        hi = min(lo + batch_size, n_items)
        samples = [dataset[k] for k in range(lo, hi)]

        batch_encodings, batch_periods, batch_genres = process_batch(samples, vae_model, device)

        # extend() on the 2-D encoding array appends one row per sample.
        encoding_rows.extend(batch_encodings)
        period_labels.extend(batch_periods)
        genre_labels.extend(batch_genres)

    return np.array(encoding_rows), np.array(period_labels), np.array(genre_labels)

In [19]:
# Encode the training split.
# The device selected earlier in the notebook is passed explicitly: the helper's
# own default is 'cuda', which would fail on a machine without a GPU even
# though `device` already fell back to 'cpu'.
encodings, periods, genres = get_encodings_and_labels(ds_train, vae_model, device=device)

# One row per dataset item: latent dimensions X1..Xn, then the numeric labels.
df = pd.DataFrame(encodings, columns=[f"X{i}" for i in range(1, encodings.shape[1] + 1)])
df['Period'] = periods
df['Genre'] = genres

100%|██████████| 2972/2972 [09:40<00:00,  5.12it/s]


In [20]:
# Lookup from numeric period id to period name. The ids are the positions in
# this list, starting at 0.
PERIOD_INDICES = dict(enumerate([
    'other',
    'Ur III',
    'Neo-Assyrian',
    'Old Babylonian',
    'Middle Babylonian',
    'Neo-Babylonian',
    'Old Akkadian',
    'Achaemenid',
    'Early Old Babylonian',
    'ED IIIb',
    'Middle Assyrian',
    'Old Assyrian',
    'Uruk III',
    'Proto-Elamite',
    'Lagash II',
    'Ebla',
    'ED IIIa',
    'Hellenistic',
    'ED I-II',
    'Middle Elamite',
    'Middle Hittite',
    'Uruk IV',
]))

In [21]:
# Lookup from numeric genre id to genre name. Ids start at 1; there is no entry
# for 0, so any such id is left as NaN when a Series is mapped through this table.
GENRE_INDICES = dict(enumerate(
    [
        'Administrative',
        'Letter',
        'Legal',
        'Royal/Monumental',
        'Literary',
        'Lexical',
        'Omen',
        'uncertain',
        'School',
        'Mathematical',
        'Prayer/Incantation',
        'Scientific',
        'Ritual',
        'fake (modern)',
        'Astronomical',
        'Private/Votive',
    ],
    start=1,
))

In [22]:
# Translate numeric period ids into readable names; ids missing from the
# lookup table come out as NaN.
period_names = df['Period'].map(PERIOD_INDICES)
df['Period_Name'] = period_names

In [23]:
# Translate numeric genre ids into readable names; ids missing from the
# lookup table come out as NaN.
genre_names = df['Genre'].map(GENRE_INDICES)
df['Genre_Name'] = genre_names

In [24]:
# Display the training frame: encoding columns X1..Xn, the numeric Period/Genre
# ids, and their mapped names.
df

Unnamed: 0,X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13,X14,X15,X16,Period,Genre,Period_Name,Genre_Name
0,4.919055,-0.909482,0.731892,3.075751,0.739380,1.116917,-0.079185,-2.698598,-0.269075,-0.048176,0.498897,-0.863170,0.132350,1.552988,-0.643640,1.021853,1,1,Ur III,Administrative
1,5.572304,0.368163,0.547902,1.861315,2.257494,-0.009318,-1.815033,-2.713016,-0.048724,-0.110121,0.344479,-0.798703,1.087145,0.339530,0.039574,1.694017,1,1,Ur III,Administrative
2,3.653273,-0.613616,0.854104,2.117145,2.047022,1.595965,0.886784,-1.895508,1.556806,0.692500,0.160684,0.265394,-1.112293,0.948919,-1.182204,-1.597994,1,1,Ur III,Administrative
3,2.979478,-0.681644,1.629997,3.078679,2.948597,0.412386,0.453947,-1.124372,-0.059042,1.094117,-0.880540,1.021095,-0.224409,0.850543,-1.316289,0.622071,1,1,Ur III,Administrative
4,2.673495,2.518288,1.001184,0.948618,1.412897,-2.765255,-1.220888,-0.643057,-0.337250,-1.413469,-1.256204,0.544692,1.556024,-0.523407,-2.559521,-0.752193,1,1,Ur III,Administrative
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95083,0.737175,1.758523,5.484700,-1.709509,1.047772,5.191993,-1.451209,-3.498457,2.893486,-0.044255,2.840436,-2.595210,0.089514,-4.235375,4.771692,-2.074991,1,1,Ur III,Administrative
95084,-0.805932,0.645323,2.743832,0.007985,-0.520500,7.281477,0.626895,-5.238637,1.697937,2.884061,-0.349408,-4.909880,0.440484,-2.714566,3.064813,-3.019037,1,4,Ur III,Royal/Monumental
95085,0.798387,3.441451,-0.875406,-1.518119,2.518986,-2.295340,2.231613,1.615952,0.901302,2.897573,0.315274,2.655319,-0.702260,-1.005092,0.267886,1.466802,3,9,Old Babylonian,School
95086,4.048042,0.295159,2.476530,-1.279784,0.812973,-0.250022,1.787098,0.191454,-1.881001,-0.611311,-1.686032,1.790781,-0.915468,0.545573,-0.012969,-0.165567,3,9,Old Babylonian,School


In [25]:
# Persist the training encodings without the index column.
# The output folder is created first so a fresh checkout does not fail at the
# very end of the (long) encoding run because the directory is missing.
train_csv_path = f"vae_encodings_and_data/vae_encoding_df_{DATE}_w_class_train.csv"
os.makedirs(os.path.dirname(train_csv_path), exist_ok=True)
df.to_csv(train_csv_path, index=False)

I will now repeat the process for the test set:

In [26]:
# Encode the test split.
# The device selected earlier in the notebook is passed explicitly: the helper's
# own default is 'cuda', which would fail on a machine without a GPU even
# though `device` already fell back to 'cpu'.
encodings, periods, genres = get_encodings_and_labels(ds_test, vae_model, device=device)

# One row per dataset item: latent dimensions X1..Xn, then the numeric labels.
df_test = pd.DataFrame(encodings, columns=[f"X{i}" for i in range(1, encodings.shape[1] + 1)])
df_test['Period'] = periods
df_test['Genre'] = genres

100%|██████████| 16/16 [00:02<00:00,  5.48it/s]


In [27]:
# Add readable name columns next to the numeric ids; ids missing from a lookup
# table come out as NaN.
for id_column, name_column, lookup in (
    ('Period', 'Period_Name', PERIOD_INDICES),
    ('Genre', 'Genre_Name', GENRE_INDICES),
):
    df_test[name_column] = df_test[id_column].map(lookup)

In [28]:
# Persist the test encodings without the index column.
# The output folder is created first so the export does not fail on a fresh
# checkout where the directory is missing.
test_csv_path = f"vae_encodings_and_data/vae_encoding_df_{DATE}_w_class_test.csv"
os.makedirs(os.path.dirname(test_csv_path), exist_ok=True)
df_test.to_csv(test_csv_path, index=False)