In [1]:
import numpy as np
from IPython.display import display, clear_output
from ipywidgets import widgets, Output
import torch
from PIL import Image
from era_data import TabletPeriodDataset, get_IDS
from VAE_model_tablets_class import VAE
import pandas as pd
from visualization_funcs import generate_image_from_VAE

In [2]:
# --- Run configuration ---------------------------------------------------
# These values must match the training run being inspected: together with the
# image count they form VERSION_NAME, which locates the checkpoint to load.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Image source directory (the un-preprocessed alternative is 'output/images').
IMG_DIR = 'output/images_preprocessed'
# Appended to the run name; set to '' for the run without this suffix.
RUN_NAME_SUFFIX = '-masked_w_classification_loss'

# Training hyper-parameters of the run (used only to rebuild its name here).
LR = 5e-5
EPOCHS = 30
BATCH_SIZE = 16
SUFFIX = '-resnet50'
DATE = 'Oct2-v3'

In [3]:
# Image IDs found under IMG_DIR; their count is part of the run name below.
IDS = get_IDS(IMG_DIR=IMG_DIR)
# Fail fast: with no images the run name (and hence the checkpoint path) would be wrong.
assert len(IDS) > 0, f'no images found under {IMG_DIR!r}'

# Must reproduce the training run's name exactly: it is the lightning_logs
# sub-directory the checkpoint is read from.
VERSION_NAME = f'period_clf_bs{BATCH_SIZE}_lr{LR}_{EPOCHS}epochs{SUFFIX}-{len(IDS)}_samples{RUN_NAME_SUFFIX}_blurvae-conv-{DATE}'

num_classes = len(TabletPeriodDataset.PERIOD_INDICES)

# Per-class loss weights. map_location='cpu' lets a tensor saved from a GPU run
# load on a CPU-only machine; the VAE appears to move them to `device` itself.
# NOTE: torch.load unpickles the file, so only load trusted files.
class_weights = torch.load("data/class_weights_period.pt", map_location='cpu')

In [4]:
# Checkpoint of the training run identified by VERSION_NAME.
# NOTE(review): 'epoch=29-step=407516' is specific to this run (epochs, dataset
# size, batch size) - update it when pointing at a different run.
checkpoint_path = f'lightning_logs/{VERSION_NAME}/checkpoints/epoch=29-step=407516.ckpt'
chekpoint_path = checkpoint_path  # misspelled alias kept for cells that still use it

In [5]:
# Restore the trained VAE. Extra keyword arguments are passed on to VAE.__init__
# and override the hyper-parameters stored in the checkpoint.
vae_model = VAE.load_from_checkpoint(
    chekpoint_path,
    map_location=device,  # map GPU-saved weights onto the device actually available
    image_channels=1,
    z_dim=16,
    lr=1e-5,  # NOTE(review): differs from LR above; presumably only matters for training
    use_classification_loss=True,
    num_classes=num_classes,
    loss_type="weighted",
    class_weights=class_weights,
    device=device,
).eval()  # inference only: any dropout / batch-norm layers use evaluation behaviour

  self.class_weights = torch.tensor(class_weights).to(device)


In [6]:
# Pre-computed VAE encodings of the DATE run. Besides the latent entries they carry
# the label columns used below ('Genre', 'Period', 'Genre_Name', 'Period_Name').
# Deriving the file names from DATE keeps them in sync with the checkpoint above.
df_encodings_train = pd.read_csv(f'vae_encodings_and_data/vae_encoding_df_{DATE}_w_class_train.csv')
df_encodings_test = pd.read_csv(f'vae_encodings_and_data/vae_encoding_df_{DATE}_w_class_test.csv')

In [16]:
def generate_image(*args):
    """Decode the current slider values with the VAE and show the image in `out`.

    Works both as an ipywidgets ``observe`` callback (the change dict it receives
    lands in ``*args`` and is ignored) and when called directly with no arguments.
    Reads the module-level ``sliders`` list, ``out`` Output widget and ``vae_model``.

    When there are no sliders (nothing matches the current filter), a blank
    128x128 white image is displayed instead.
    """
    # One value per latent entry, in slider order.
    vector = np.array([slider.value for slider in sliders])

    if len(vector) == 0:
        # Nothing to decode: show a blank placeholder.
        image = Image.new('RGB', (128, 128), color='white')
    else:
        generated_image_tensor = generate_image_from_VAE(vector, vae_model)
        # detach/cpu so .numpy() also works for grad-tracking or GPU tensors.
        pixels = generated_image_tensor.detach().cpu().numpy()
        # Assumes decoder output in [0, 1]; clipping makes out-of-range values
        # saturate at 0/255 instead of wrapping around in the uint8 cast.
        generated_image_np = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
        image = Image.fromarray(generated_image_np)

    # wait=True replaces the previous image only once the new one is displayed.
    with out:
        clear_output(wait=True)
        display(image)

In [21]:
from ipywidgets import HBox, VBox, Dropdown, Output, FloatSlider

# Requires df_encodings_train with 'Period_Name' and 'Genre_Name' columns.

def _dropdown_options(column):
    """Return ['None'] followed by the sorted distinct non-null values of `column`."""
    # dropna: a missing name is a float NaN, which `sorted` cannot compare with str.
    return ['None'] + sorted(df_encodings_train[column].dropna().unique().tolist())

# Dropdown for period selection ('None' = no period filter).
period_options = _dropdown_options('Period_Name')
period_dropdown = Dropdown(options=period_options, value='None', description='Period:')

# Dropdown for genre selection ('None' = no genre filter).
genre_options = _dropdown_options('Genre_Name')
genre_dropdown = Dropdown(options=genre_options, value='None', description='Genre:')

# One slider per latent entry; filled in by update_output().
sliders = []
sliders_container = VBox([])  # Initially empty container

# Output widget for displaying the decoded image
out = Output()

# Function to update the output based on dropdown selections
def update_output(*args):
    """Rebuild the latent sliders from a random encoding matching the dropdowns.

    Serves as the ``observe`` callback of both dropdowns (the change dict lands in
    ``*args`` and is ignored) and is also called once directly to initialise the UI.

    Steps:
    - Filter ``df_encodings_train`` by the selected period and/or genre ('None'
      means no filter).
    - Sample one matching row.
    - Create one slider per remaining column, initialised to that row's value.
    - Redraw via ``generate_image``.

    If nothing matches, all sliders are removed and ``generate_image`` shows a
    blank image.
    """
    period = period_dropdown.value if period_dropdown.value != 'None' else None
    genre = genre_dropdown.value if genre_dropdown.value != 'None' else None

    imgs_to_disp_df = df_encodings_train.drop(['Genre', 'Period'], axis=1)

    if period:
        imgs_to_disp_df = imgs_to_disp_df[imgs_to_disp_df['Period_Name'] == period]
    if genre:
        imgs_to_disp_df = imgs_to_disp_df[imgs_to_disp_df['Genre_Name'] == genre]

    # Remember the current sliders so they can be closed once replaced.
    old_sliders = list(sliders)
    sliders.clear()

    if not imgs_to_disp_df.empty:
        # NOTE(review): assumes every column except the two name columns is a latent
        # entry - confirm the CSV has no stray index column (e.g. 'Unnamed: 0').
        initial_vector = (
            imgs_to_disp_df.sample()
            .drop(['Genre_Name', 'Period_Name'], axis=1)
            .iloc[0].values.astype('float32')
        )
        for i, val in enumerate(initial_vector):
            val = float(val)
            # FloatSlider clamps its value into [min, max]. Widen the default +-3
            # range when needed so the sampled encoding is shown and decoded exactly.
            slider = FloatSlider(value=val, min=min(-3.0, val), max=max(3.0, val),
                                 step=0.01, description=f'Entry {i+1}')
            slider.observe(generate_image, 'value')
            sliders.append(slider)

    sliders_container.children = list(sliders)  # empty list clears the display

    # Close replaced sliders so their widget models do not accumulate in the
    # kernel/front-end on every dropdown change.
    for old_slider in old_sliders:
        old_slider.close()

    # No sliders -> blank image; otherwise decode the sampled vector.
    generate_image()

# Rebuild the sliders whenever either dropdown changes.
for _dropdown in (period_dropdown, genre_dropdown):
    _dropdown.observe(update_output, 'value')

# Sliders on the left, decoded image on the right.
layout_container = HBox([sliders_container, out])
display(period_dropdown, genre_dropdown, layout_container)

# Populate sliders and image for the initial (unfiltered) selection.
update_output()


Dropdown(description='Period:', options=('None', 'Achaemenid', 'ED I-II', 'ED IIIa', 'ED IIIb', 'Early Old Bab…

Dropdown(description='Genre:', options=('None', 'Administrative', 'Astronomical', 'Legal', 'Letter', 'Lexical'…

HBox(children=(VBox(), Output()))

## Conclusions (tentative): what each latent entry seems to control when its slider is moved

* Entry 1: how wide the tablet is, plus the number of views (angles) shown; possibly rectangular vs. round?
* Entry 2: whether the image is placed to the left or not / how big the centerpiece is, plus the thickness of the tablet
* Entry 3: the gap between the views (angles), plus whether the shape is more square
* Entry 4: whether, and how much, the bottom left of the tablet is chipped, plus thickness
* Entry 5: whether, and how much, the top left of the tablet is chipped, plus the narrow side views
* Entry 6: how round the tablet is / a smaller side centerpiece?
* Entry 7: whether or not the photo is taken with one of the sides placed on top
* Entry 8: how round the tablet is
* Entry 9: whether the tablet is placed on the left or the right, plus how narrow it is (is there a shadow on the right?)
* Entry 10: whether a single view or all views (angles) are shown
* Entry 11: whether there are many views or just the six sides (e.g. a cylinder or multiple fragments)
* Entry 12: unclear; seems related to the number of fragments / how short the tablet is
* Entry 13: unclear; seems related to the number of fragments / how wide the tablet is (is there a shadow on the left?)
* Entry 14: unclear; seems related to the number of fragments / whether the bottom right of the centerpiece is chipped
* Entry 15: unclear; seems related to the number of fragments
* Entry 16: whether a single view or all views (angles) are shown (apparently the same effect as Entry 10)