In [1]:
import torch
from pathlib import Path
from torchvision import transforms
from utils import ids_to_tokens, img_transformation
from modelMobileNetV3 import Encoder, Decoder
from tqdm import tqdm
import pickle
from GUIconverter.GUIconverter import GUIconverter
from IPython.display import display, HTML, Image
from vocab import Vocab
from PIL import Image
from torchvision import transforms

In [2]:
# Configuration parameters
model_file_path = "./models/ED--epoch-1--loss-0.10679.pth" 
img_crop_size = 224
seed = 42

# Seed PyTorch's global RNG so any torch-level randomness below is repeatable.
# This only covers PyTorch; other RNGs used by helper libraries are not seeded here.
torch.manual_seed(seed)

# Load the saved checkpoint onto the CPU. Without map_location, a checkpoint
# written from a GPU session fails to load on a machine without CUDA, and the
# rest of this notebook feeds CPU tensors to the models anyway.
# NOTE(review): torch.load unpickles data; only load checkpoints you trust.
loaded_model = torch.load(model_file_path, map_location="cpu")
vocab = loaded_model['vocab']

# Architecture hyperparameters.
# NOTE(review): presumably these must match the values used when the
# checkpoint was trained, otherwise load_state_dict below will fail; confirm.
embed_size = 64
hidden_size = 256
num_layers = 2

encoder = Encoder(embed_size)
decoder = Decoder(embed_size, hidden_size, len(vocab), num_layers)

# Load model weights
encoder.load_state_dict(loaded_model["encoder_model_state_dict"])
decoder.load_state_dict(loaded_model["decoder_model_state_dict"])

<All keys matched successfully>

In [3]:
# Switch the encoder to evaluation mode before inference (it contains BatchNorm
# layers, as the printed module structure shows). The call is the cell's last
# expression, so the module's repr is displayed as output.
encoder.eval()


Encoder(
  (mobilenet): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), str

In [4]:
# Switch the decoder to evaluation mode as well, mirroring the encoder.
# The call is the cell's last expression, so the module's repr is displayed.
decoder.eval()

Decoder(
  (embed): Embedding(17, 64)
  (lstm): LSTM(64, 256, num_layers=2, batch_first=True)
  (linear): Linear(in_features=256, out_features=17, bias=True)
)

In [5]:
# Load the image
image_path = './viewer.png'  # Change to your image's path
# `Image` is PIL's Image module here: the `from PIL import Image` in the import
# cell comes after (and therefore shadows) the one from IPython.display.
# Open the file in a context manager so its handle is released deterministically;
# convert('RGB') yields an independent image that stays valid after the close.
with Image.open(image_path) as source_image:
    image = source_image.convert('RGB')
# Build the preprocessing pipeline for the configured crop size and apply it.
# NOTE(review): presumably returns a (C, H, W) tensor, since a batch dimension
# is added with unsqueeze(0) before the encoder call; confirm in utils.
transform = img_transformation(img_crop_size)
transformed_image = transform(image)

In [6]:
# Model prediction
# Inference only: run the forward pass under no_grad so autograd does not
# record a graph, which saves memory and time.
with torch.no_grad():
    features = encoder(transformed_image.unsqueeze(0))  # Unsqueeze to add batch dimension
    # detach() replaces the legacy `.data` accessor; the result is a NumPy array of ids.
    predicted_ids = decoder.sample(features).detach().cpu().numpy()
prediction = ids_to_tokens(vocab, predicted_ids)  # Assuming this function converts ids to tokens

# Convert to HTML
transpiler = GUIconverter(style='style6')
# insert_random_text=True is passed through unchanged.
# NOTE(review): presumably fills the generated page with placeholder text,
# so output may differ between runs; confirm in GUIconverter.
predicted_html_string = transpiler.transpile(prediction, insert_random_text=True)

In [7]:
def display_html_string(html_string):
    """Render a raw HTML string as rich output in the notebook.

    Args:
        html_string: Markup to wrap in an IPython ``HTML`` object and display.

    Returns:
        None. The rendered page appears as the cell's output.
    """
    display(HTML(html_string))

In [8]:
# Render the HTML generated from the model's prediction as this cell's output.
display_html_string(predicted_html_string)