In [1]:
import io

from IPython.display import display
import ipywidgets as widgets
import numpy as np
import pandas as pd
from PIL import Image as PILImage
import torch

from era_data import TabletPeriodDataset, get_IDS
from VAE_model_tablets_class import VAE

# Hyperparameters

In [2]:
# Run configuration: where the preprocessed images live, the batch size, and
# the name of the training run (its lightning_logs folder) to load.
IMG_DIR = 'output/images_preprocessed'
BATCH_SIZE = 16
VERSION_NAME = ('period_clf_bs16_lr5e-05_beta_1_epochs_30-VAE-94936_samples-'
                'masked_w_classification_loss-equalpartsloss-March29')

# Collect the sample ids found under IMG_DIR and report how many there are.
IDS = get_IDS(IMG_DIR=IMG_DIR)
print(len(IDS))

94936


In [3]:
num_classes = len(TabletPeriodDataset.PERIOD_INDICES)

# Per-class loss weights handed to the VAE's weighted loss below.
# map_location="cpu" lets the file load on a CPU-only machine even if it was
# saved from CUDA tensors; the VAE copies the weights onto `device` itself
# (see the warning its constructor emits). torch.load unpickles, so only
# point it at trusted files.
class_weights = torch.load("data/class_weights_period.pt", map_location="cpu")

# Load model

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
checkpoint_path = f'lightning_logs/{VERSION_NAME}/checkpoints/epoch=29-step=170458.ckpt'

# NOTE(review): these constructor kwargs presumably mirror the training run
# encoded in VERSION_NAME (e.g. lr5e-05); confirm before changing any of them.
vae_model = VAE.load_from_checkpoint(
    checkpoint_path,
    image_channels=1,
    z_dim=12,
    lr=5e-5,
    use_classification_loss=True,
    num_classes=num_classes,
    loss_type="weighted",
    class_weights=class_weights,
    device=device,
)
# The model is only used for inference below, so switch any dropout /
# batch-norm layers (if the VAE has them) to evaluation behaviour.
# The trailing ';' keeps Jupyter from echoing the module repr.
vae_model.eval();

  self.class_weights = torch.tensor(class_weights).to(device)


# Load data

In [6]:
# Pre-computed VAE encoding tables (with period/genre metadata columns) for
# the train and test splits. Only the split suffix differs between the files.
ENCODINGS_DIR = 'vae_encodings_and_data'
ENCODINGS_TAG = 'vae_encoding_df_March30_w_class'
df_encodings_train = pd.read_csv(f'{ENCODINGS_DIR}/{ENCODINGS_TAG}_train.csv')
df_encodings_test = pd.read_csv(f'{ENCODINGS_DIR}/{ENCODINGS_TAG}_test.csv')

In [7]:
# Average every remaining (encoding) column per period: one row per
# Period_Name, with Period_Name restored as an ordinary column.
df_means = (
    df_encodings_train
    .drop(columns=["Period", "Genre", "Genre_Name", "CDLI_id"])
    .groupby("Period_Name")
    .mean()
    .reset_index()
)

# Interpolation Widget

In [11]:
# Choices for both selectors: every period that has a row in df_means.
period_names = df_means['Period_Name'].unique()


def _period_dropdown(label):
    """Build one period selector; each call gets its own Layout instance."""
    return widgets.Dropdown(
        options=period_names,
        description=label,
        layout=widgets.Layout(width='60%'),
    )


# Decoded-image display, the interpolation weight (0 = period 1, 1 = period 2),
# the two period selectors and the button that triggers a fresh render.
image_widget = widgets.Image(
    layout=widgets.Layout(height='200px', width='200px', border='2px solid black'),
)
slider = widgets.FloatSlider(
    value=0, min=0, max=1, step=0.1, description='Interpolation',
)
dropdown1 = _period_dropdown('Period 1')
dropdown2 = _period_dropdown('Period 2')
interpolate_button = widgets.Button(description="Interpolate")

def get_image_from_period(period_name):
    """Return the mean latent vector of ``period_name`` as a float32 tensor.

    Despite its name, this does not return an image. It selects the row of
    ``df_means`` whose ``Period_Name`` equals ``period_name`` and returns all
    other columns of that row as a 1-D ``torch.float32`` tensor, ready to be
    interpolated and decoded.

    Raises:
        ValueError: if ``period_name`` has no row in ``df_means`` (previously
            this surfaced as an opaque ``IndexError`` from ``.values[0]``).
    """
    matches = df_means.loc[df_means["Period_Name"] == period_name]
    if matches.empty:
        known = list(df_means["Period_Name"])
        raise ValueError(f"Unknown period {period_name!r}; expected one of {known}")
    latent = matches.drop(columns=["Period_Name"]).to_numpy()[0].astype("float32")
    return torch.from_numpy(latent)

def generate_image(*args):
    """Decode a blend of two period mean latents and show it in ``image_widget``.

    Reads the periods chosen in ``dropdown1`` / ``dropdown2`` and the weight
    from ``slider`` (0 -> period 1, 1 -> period 2), linearly blends the two
    mean latent vectors, passes the blend through ``vae_model.fc3`` and
    ``vae_model.decoder``, and stores the result as PNG bytes in
    ``image_widget.value``.

    ``*args`` absorbs the change dict that ``slider.observe`` passes, so the
    function can also be called directly.
    """
    latent1 = get_image_from_period(dropdown1.value)
    latent2 = get_image_from_period(dropdown2.value)
    weight = slider.value
    blended = (1 - weight) * latent1 + weight * latent2

    # Pure inference: fc3 belongs inside no_grad too. Otherwise an autograd
    # graph is built on every slider move.
    with torch.no_grad():
        decoder_input = vae_model.fc3(blended).unsqueeze(0)
        generated = vae_model.decoder(decoder_input)

    # NOTE(review): assumes the decoder returns (batch, channel, H, W) with
    # values nominally in [0, 1]; confirm against VAE_model_tablets_class.
    pixels = generated[0][0].cpu().numpy()
    # Clip before scaling. Out-of-range floats cast to uint8 wrap around or
    # are undefined (e.g. 1.01 * 255 = 257.55), producing speckled pixels.
    pixels = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)

    buffer = io.BytesIO()
    PILImage.fromarray(pixels).save(buffer, format='PNG')
    image_widget.value = buffer.getvalue()

def reset_slider(*_ignored):
    """Move the interpolation slider back to 0, i.e. fully the first period.

    Extra positional arguments are accepted and ignored so this can also be
    attached as a widget callback.
    """
    slider.value = 0
    
def interpolate_and_display(*args):
    """Button callback: rewind the slider to 0 and render exactly once.

    ``*args`` absorbs the button instance that ``on_click`` passes.
    """
    if slider.value == 0:
        # Assigning an unchanged value fires no change event, so render here.
        generate_image()
    else:
        # Resetting fires the slider's ``value`` observer, which already calls
        # generate_image(). Calling it again here would decode twice.
        reset_slider()

# Callbacks: every slider change re-renders; the button resets the slider
# and re-renders.
slider.observe(generate_image, names='value')
interpolate_button.on_click(interpolate_and_display)

# Layout: a centered row of period selectors, then the button and slider,
# then the centered image, all stacked in one centered column.
dropdowns = widgets.HBox(
    [dropdown1, dropdown2],
    layout=widgets.Layout(justify_content='center'),
)
image_container = widgets.HBox(
    [image_widget],
    layout=widgets.Layout(justify_content='center'),
)
controls_and_display = widgets.VBox(
    [dropdowns, interpolate_button, slider, image_container],
    layout=widgets.Layout(align_items='center'),
)

display(controls_and_display)

VBox(children=(HBox(children=(Dropdown(description='Period 1', layout=Layout(width='60%'), options=('Achaemeni…