In [1]:
import os
from multiprocessing import Pool

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from era_data import TabletPeriodDataset, get_IDS
from VAE_model_tablets_class import VAE

In [2]:
# --- Data location ------------------------------------------------------------
# Image directory handed to get_IDS and TabletPeriodDataset in later cells
# (alternative value: 'output/images').
IMG_DIR = 'output/images_preprocessed'

# --- Hyper-parameters of the run being loaded ---------------------------------
# In the cells shown here these values are only interpolated into VERSION_NAME,
# which in turn locates the checkpoint directory of that run.
BATCH_SIZE = 16
LR = 5e-5
EPOCHS = 30

# --- Remaining run-name fragments ---------------------------------------------
SUFFIX = '-resnet50'
RUN_NAME_SUFFIX = '-masked_w_classification_loss'  # alternative value: ''
DATE = 'Oct2-v3'  # also embedded in the names of the CSV files written at the end

In [3]:
# Use the GPU when PyTorch reports one as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
# Fetch the IDs that get_IDS returns for IMG_DIR and show how many there are.
# The count is also interpolated into VERSION_NAME in the next cell.
IDS = get_IDS(IMG_DIR=IMG_DIR)
len(IDS)

97640

In [5]:
# Rebuild the run name from the config values and the number of IDs. It is used
# further down as the directory under lightning_logs/ that holds the checkpoint,
# so it has to match that directory name exactly.
VERSION_NAME = f'period_clf_bs{BATCH_SIZE}_lr{LR}_{EPOCHS}epochs{SUFFIX}-{len(IDS)}_samples{RUN_NAME_SUFFIX}_blurvae-conv-{DATE}'
VERSION_NAME

'period_clf_bs16_lr5e-05_30epochs-resnet50-97640_samples-masked_w_classification_loss_blurvae-conv-Oct2-v3'

In [6]:
# Name of the run whose saved train/test ID lists are reused here. This is a
# hard-coded literal and is not the same string as VERSION_NAME.
VERNAME = 'period_clf_bs16_lr5e-05_15epochs-resnet50-97640_samples-masked-Sept_19_blur-early_stopping-1'

# Each CSV is read without a header row; column 0 is taken and cast to str.
# NOTE(review): presumably one ID per row - confirm against the code that wrote these files.
train_ids = pd.read_csv(f'output/clf_ids/period-train-{VERNAME}.csv', header=None)[0].astype(str)
test_ids = pd.read_csv(f'output/clf_ids/period-test-{VERNAME}.csv', header=None)[0].astype(str)

In [7]:
# Build one dataset per split from the ID lists loaded above.
# NOTE(review): what mask=True does is defined in era_data.TabletPeriodDataset - confirm there.
ds_train = TabletPeriodDataset(IDS=train_ids, IMG_DIR=IMG_DIR, mask=True)
ds_test = TabletPeriodDataset(IDS=test_ids, IMG_DIR=IMG_DIR, mask=True)

Filtering 97640 IDS down to provided 97140...
Filtering 97640 IDS down to provided 500...


In [8]:
# Number of period classes, taken as the size of TabletPeriodDataset.PERIOD_INDICES.
# Passed to the VAE as num_classes when the checkpoint is loaded.
num_classes = len(TabletPeriodDataset.PERIOD_INDICES)
num_classes

22

In [9]:
# Class weights handed to the VAE below through its class_weights argument.
# NOTE(review): torch.load unpickles the file, so only load it from a trusted source.
# NOTE(review): if the file was saved from a GPU, loading on a CPU-only machine may
# need map_location - confirm.
class_weights = torch.load("data/class_weights_period.pt")

In [10]:
# Checkpoint file of the run named by VERSION_NAME. The epoch/step part of the
# file name is hard-coded and has to exist in that run's checkpoints/ directory.
chekpoint_path = f'lightning_logs/{VERSION_NAME}/checkpoints/epoch=29-step=407516.ckpt'

# Restore the VAE from the checkpoint, one keyword argument per line.
# NOTE(review): lr=1e-5 here differs from LR (5e-5) in the config cell - confirm this is intended.
vae_model = VAE.load_from_checkpoint(
    chekpoint_path,
    image_channels=1,
    z_dim=16,
    lr=1e-5,
    use_classification_loss=True,
    num_classes=num_classes,
    loss_type="weighted",
    class_weights=class_weights,
    device=device,
)

  self.class_weights = torch.tensor(class_weights).to(device)


In [11]:
def preprocess_image(img, label, genre, size=(178, 218)):
    """Resize one image to a fixed size, passing its labels through unchanged.

    Args:
        img: Image as a NumPy array that ``PIL.Image.fromarray`` accepts.
        label: Period label belonging to ``img``; returned as-is.
        genre: Genre label belonging to ``img``; returned as-is.
        size: Target size forwarded to ``PIL.Image.Image.resize``, which takes
            it as ``(width, height)``. Defaults to ``(178, 218)``, the value
            that was previously hard-coded in this function.

    Returns:
        Tuple ``(resized_img, label, genre)`` where ``resized_img`` is a NumPy array.
    """
    # Nearest-neighbour resampling, exactly as before; only the size is now configurable.
    resized = Image.fromarray(img).resize(size, Image.NEAREST)
    return np.array(resized), label, genre

def process_batch(batch, vae_model, device):
    """Encode one batch of dataset items with the VAE.

    Args:
        batch: Non-empty sequence of ``(img, label, genre)`` items.
        vae_model: Model exposing ``representation(tensor)``. It is not moved
            here, so the caller must already have placed it on ``device``.
        device: Device the stacked image tensor is moved to before encoding.

    Returns:
        Tuple ``(encodings, periods, genres)``. ``encodings`` is the model
        output moved to the CPU and converted to a NumPy array; ``periods``
        and ``genres`` are tuples of the labels in batch order.

    Raises:
        ValueError: If ``batch`` is empty (the three-way unpacking below fails).
    """
    # Resize every image to a common size so the batch can be stacked
    preprocessed_batch = [preprocess_image(img, label, genre) for img, label, genre in batch]
    imgs, periods, genres = zip(*preprocessed_batch)

    # Build the transform once per batch instead of once per image
    to_tensor = transforms.ToTensor()
    img_tensors = torch.stack([to_tensor(img) for img in imgs]).to(device)

    # Inference only: no gradients are needed for the encodings
    with torch.no_grad():
        encodings = vae_model.representation(img_tensors)
    encodings = encodings.cpu().numpy()

    return encodings, periods, genres

In [12]:
def get_encodings_and_labels(dataset, vae_model, device='cuda', batch_size=32):
    """Encode every item of ``dataset`` with ``vae_model`` and gather its labels.

    The model is moved to ``device`` and switched to eval mode. The dataset is
    then walked in consecutive index chunks of at most ``batch_size`` items,
    and each chunk is encoded by ``process_batch``.

    Args:
        dataset: Indexable collection supporting ``len()`` whose items are
            accepted by ``process_batch``.
        vae_model: Model passed through to ``process_batch``.
        device: Device for the model and the image tensors.
        batch_size: Maximum number of items per chunk.

    Returns:
        Tuple of three NumPy arrays ``(encodings, periods, genres)`` built from
        the per-item results in dataset order.
    """
    vae_model.to(device)
    vae_model.eval()

    n_items = len(dataset)
    full_chunks, remainder = divmod(n_items, batch_size)
    n_chunks = full_chunks + (1 if remainder else 0)  # ceiling of n_items / batch_size

    # Accumulators, in the order process_batch returns them: encodings, periods, genres
    collected = ([], [], [])

    for chunk_no in tqdm(range(n_chunks)):
        lo = chunk_no * batch_size
        hi = min(lo + batch_size, n_items)
        chunk = [dataset[k] for k in range(lo, hi)]
        for bucket, values in zip(collected, process_batch(chunk, vae_model, device)):
            bucket.extend(values)

    return tuple(np.array(bucket) for bucket in collected)

In [13]:
# Encode the training split. The device detected near the top of the notebook is
# passed explicitly: the function's own default is 'cuda', which fails on a
# machine without a GPU.
encodings, periods, genres = get_encodings_and_labels(ds_train, vae_model, device=device)

# One row per encoded item: latent columns X1..Xn followed by the two label columns
df = pd.DataFrame(encodings, columns=[f"X{i}" for i in range(1, encodings.shape[1] + 1)])
df['Period'] = periods
df['Genre'] = genres

100%|██████████| 2972/2972 [09:27<00:00,  5.23it/s]


In [14]:
# Integer period label -> period name, used to fill the 'Period_Name' column.
# The names are listed in label order, so the position of a name is its label.
# NOTE(review): TabletPeriodDataset also carries a PERIOD_INDICES attribute;
# presumably this table must stay in sync with it - confirm.
PERIOD_INDICES = dict(enumerate((
    'other',                 # 0
    'Ur III',                # 1
    'Neo-Assyrian',          # 2
    'Old Babylonian',        # 3
    'Middle Babylonian',     # 4
    'Neo-Babylonian',        # 5
    'Old Akkadian',          # 6
    'Achaemenid',            # 7
    'Early Old Babylonian',  # 8
    'ED IIIb',               # 9
    'Middle Assyrian',       # 10
    'Old Assyrian',          # 11
    'Uruk III',              # 12
    'Proto-Elamite',         # 13
    'Lagash II',             # 14
    'Ebla',                  # 15
    'ED IIIa',               # 16
    'Hellenistic',           # 17
    'ED I-II',               # 18
    'Middle Elamite',        # 19
    'Middle Hittite',        # 20
    'Uruk IV',               # 21
)))

In [15]:
# Integer genre label -> genre name, used to fill the 'Genre_Name' column.
# Labels start at 1; there is no entry for 0.
GENRE_INDICES = dict(enumerate((
    'Administrative',      # 1
    'Letter',              # 2
    'Legal',               # 3
    'Royal/Monumental',    # 4
    'Literary',            # 5
    'Lexical',             # 6
    'Omen',                # 7
    'uncertain',           # 8
    'School',              # 9
    'Mathematical',        # 10
    'Prayer/Incantation',  # 11
    'Scientific',          # 12
    'Ritual',              # 13
    'fake (modern)',       # 14
    'Astronomical',        # 15
    'Private/Votive',      # 16
), start=1))

In [16]:
# Translate each period label into its name; a label that is not a key of
# PERIOD_INDICES ends up as NaN.
df['Period_Name'] = df['Period'].map(PERIOD_INDICES)

In [17]:
# Translate each genre label into its name. GENRE_INDICES has no key 0, so a
# genre label of 0 (or any other unknown label) ends up as NaN.
df['Genre_Name'] = df['Genre'].map(GENRE_INDICES)

In [18]:
# Display the assembled training frame (encodings, labels and label names).
df

Unnamed: 0,X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13,X14,X15,X16,Period,Genre,Period_Name,Genre_Name
0,0.695620,-0.326637,1.869451,1.120124,-3.230606,1.078223,-2.286818,-0.213600,-2.957026,0.395746,0.686147,2.156204,1.430726,-0.226425,0.980338,-1.268349,1,1,Ur III,Administrative
1,0.488771,-0.505401,2.806634,-0.011833,-3.098880,0.062783,-1.171744,0.740006,-3.654613,0.205812,1.805598,1.383725,0.076415,-0.259734,0.640475,-2.054135,1,1,Ur III,Administrative
2,-0.164733,-0.477150,0.875877,1.744388,-1.792805,1.181497,-1.963444,-1.465748,-3.112858,1.478121,0.527789,-0.672934,0.073764,-0.231725,1.831303,-1.136168,1,1,Ur III,Administrative
3,0.032329,0.733997,1.136432,0.980828,-0.611278,0.674751,-2.201123,-0.782052,-3.235670,-0.053801,2.460866,0.277234,1.209431,-0.855141,1.686832,-1.203823,1,1,Ur III,Administrative
4,-2.117302,0.836887,2.241112,-2.359340,-0.441417,0.800746,0.511654,-1.293421,-2.536816,-0.735655,1.384294,-0.431612,1.212609,2.165188,0.814998,-1.704666,1,1,Ur III,Administrative
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95083,0.239176,-1.923309,-2.530270,-1.414568,-2.259939,3.442629,-1.488173,3.042659,-1.457338,2.377469,-1.447516,-0.033868,-5.858948,-0.102388,-0.006196,-2.442894,1,1,Ur III,Administrative
95084,1.455029,-2.804560,-1.519120,-0.844761,0.119485,4.548380,-3.053493,2.669134,-1.779485,3.313068,-4.855679,-0.840578,-1.970150,-2.371216,-2.190817,-2.478782,1,4,Ur III,Royal/Monumental
95085,-1.905614,2.131715,1.252587,-1.206944,-1.331858,-1.755407,0.406348,0.143557,1.187520,-2.829051,-0.125474,-1.720152,-1.169059,-2.028571,3.087566,-1.271153,3,9,Old Babylonian,School
95086,-0.179766,-0.862341,1.356933,-0.642790,0.994281,-0.645911,-0.318751,-1.104957,-0.800633,0.576582,-0.364576,1.692541,0.775255,-0.038733,4.185235,-1.850742,3,9,Old Babylonian,School


In [19]:
# Write the training encodings to disk. The output directory is created first
# because DataFrame.to_csv does not create missing parent directories, and a
# failure here would discard the result of the long encoding run above.
train_csv_path = f"vae_encodings_and_data/vae_encoding_df_{DATE}_w_class_train.csv"
os.makedirs(os.path.dirname(train_csv_path), exist_ok=True)
df.to_csv(train_csv_path, index=False)

I will now repeat the process for the test set:

In [20]:
# Encode the test split. The device detected near the top of the notebook is
# passed explicitly: the function's own default is 'cuda', which fails on a
# machine without a GPU.
encodings, periods, genres = get_encodings_and_labels(ds_test, vae_model, device=device)

# One row per encoded item: latent columns X1..Xn followed by the two label columns
df_test = pd.DataFrame(encodings, columns=[f"X{i}" for i in range(1, encodings.shape[1] + 1)])
df_test['Period'] = periods
df_test['Genre'] = genres

100%|██████████| 16/16 [00:03<00:00,  5.25it/s]


In [21]:
# Add the human-readable names for the test split; labels that are not keys of
# the respective lookup table end up as NaN.
df_test['Period_Name'] = df_test['Period'].map(PERIOD_INDICES)
df_test['Genre_Name'] = df_test['Genre'].map(GENRE_INDICES)

In [22]:
# Write the test encodings to disk. The output directory is created first
# because DataFrame.to_csv does not create missing parent directories.
test_csv_path = f"vae_encodings_and_data/vae_encoding_df_{DATE}_w_class_test.csv"
os.makedirs(os.path.dirname(test_csv_path), exist_ok=True)
df_test.to_csv(test_csv_path, index=False)