In [57]:
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
from ipywidgets import widgets, Output
import torch
from PIL import Image
from era_data import TabletPeriodDataset, get_IDS
from VAE_model_tablets_class import VAE
import pandas as pd

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# IMG_DIR = 'output/images'
RUN_NAME_SUFFIX = '-masked_w_classification_loss' # ''
IMG_DIR = 'output/images_preprocessed'
LR = 5e-5
EPOCHS = 30
BATCH_SIZE = 16
SUFFIX = '-resnet50'
DATE = 'Oct2-v3'

In [6]:
IDS = get_IDS(IMG_DIR=IMG_DIR)
len(IDS)

VERSION_NAME = f'period_clf_bs{BATCH_SIZE}_lr{LR}_{EPOCHS}epochs{SUFFIX}-{len(IDS)}_samples{RUN_NAME_SUFFIX}_blurvae-conv-{DATE}'

num_classes = len(TabletPeriodDataset.PERIOD_INDICES)

class_weights = torch.load("data/class_weights_period.pt")

In [7]:
chekpoint_path = f'lightning_logs/{VERSION_NAME}/checkpoints/epoch=29-step=407516.ckpt'

In [10]:
vae_model = VAE.load_from_checkpoint(chekpoint_path,image_channels=1,z_dim=16, lr =1e-5, use_classification_loss=True, num_classes=num_classes,
            loss_type="weighted", class_weights=class_weights, device = device)

  self.class_weights = torch.tensor(class_weights).to(device)


In [13]:
df_encodings_train = pd.read_csv('vae_encodings_and_data/vae_encoding_df_Oct2-v3_w_class_train.csv')
df_encodings_test = pd.read_csv('vae_encodings_and_data/vae_encoding_df_Oct2-v3_w_class_test.csv')

In [75]:
def generate_image(*args):
    # Clear the output widget
    out.clear_output(wait=True)

    # Fetch values from sliders
    vector = np.array([slider.value for slider in sliders])
    
    # Convert to tensor and generate image using the model as you provided
    input_tensor = torch.from_numpy(vector).unsqueeze(0).float()
    new_tab_long = vae_model.fc3(input_tensor)

    with torch.no_grad():
        generated_image_tensor = vae_model.decoder(new_tab_long)
        # Assuming the generated image is in the range [0, 1], so we multiply by 255
        generated_image_np = (generated_image_tensor.squeeze(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        generated_image_np = generated_image_np.squeeze(-1)
    # Convert to PIL Image and display inside the output widget
    image = Image.fromarray(generated_image_np)
    with out:
        display(image)

In [78]:
out = Output()  # Output widget to display the image

sliders = []

period = 'ED I-II' # can be None to include just genre
genre = '' # can be None to include just period

imgs_to_disp_df = df_encodings_train.drop(['Genre', 'Period'], axis = 1).copy()

if period:
     imgs_to_disp_df = imgs_to_disp_df[imgs_to_disp_df['Period_Name']==period]
    
if genre:
     imgs_to_disp_df = imgs_to_disp_df[imgs_to_disp_df['Genre_Name']==genre]
    
initial_vector = imgs_to_disp_df.sample().drop(['Genre_Name', 'Period_Name'], axis = 1).iloc[0].values.astype('float32')

for i, val in enumerate(initial_vector):
    slider = widgets.FloatSlider(value=val, min=-3, max=3, step=0.01, description=f'Entry {i+1}')
    slider.observe(generate_image, 'value')
    sliders.append(slider)
    
# Show sliders
for slider in sliders:
    display(slider)

# Show initial image
display(out)
generate_image()

FloatSlider(value=0.35628050565719604, description='Entry 1', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=0.1713632345199585, description='Entry 2', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=0.24868468940258026, description='Entry 3', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=0.4406053125858307, description='Entry 4', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=-0.7359985709190369, description='Entry 5', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=-1.511257290840149, description='Entry 6', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=3.0, description='Entry 7', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=1.7927883863449097, description='Entry 8', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=-1.6779472827911377, description='Entry 9', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=-1.1771475076675415, description='Entry 10', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=-1.486824631690979, description='Entry 11', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=1.8727010488510132, description='Entry 12', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=0.5334859490394592, description='Entry 13', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=-0.918647289276123, description='Entry 14', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=-0.7044676542282104, description='Entry 15', max=3.0, min=-3.0, step=0.01)

FloatSlider(value=0.9711862206459045, description='Entry 16', max=3.0, min=-3.0, step=0.01)

Output()