In [None]:
from model import ImageMerger
import torch

# Load the trained model.
# Prefer Apple-silicon MPS (the original target), then CUDA, then CPU, so the
# notebook still runs on machines that have no MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model_path = "model_epoch_25.pth"  # Adjust path if needed
# NOTE: depending on the torch version / `weights_only` default, torch.load may
# unpickle arbitrary objects -- only load checkpoints you trust.
checkpoint = torch.load(model_path, map_location=device)
model = ImageMerger().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # Set model to evaluation mode

In [None]:
from dataset import getEmojiUrl
from PIL import Image
import requests
from torchvision import transforms
import matplotlib.pyplot as plt

# Function to load an image from URL and apply transformations
def load_image_from_url(url, size=128):
    """Download an image and return it as a resized tensor.

    Args:
        url: HTTP(S) URL of the image to fetch.
        size: Side length in pixels; the image is resized to (size, size).

    Returns:
        The tensor produced by ``transforms.ToTensor`` for the resized image,
        shaped (channels, size, size). The channel count follows the image's
        PIL mode (e.g. RGB -> 3, RGBA -> 4), and 8-bit images are scaled to [0, 1].

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on connection failure or timeout.
    """
    from io import BytesIO

    transform = transforms.Compose([transforms.Resize((size, size)), transforms.ToTensor()])
    # - The timeout keeps the notebook from hanging forever on a stalled request.
    # - raise_for_status() fails loudly instead of handing an HTML error page to PIL.
    # - `.content` (unlike `.raw`) undoes any gzip/deflate Content-Encoding.
    # - The `with` block releases the connection back to the pool.
    with requests.get(url, timeout=30) as response:
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return transform(image)

# Emoji code-point pairs to merge (left, right); add as many pairs as you want to check
image_pairs = [
    ("1fa84", "2615"),
    ("1f600", "2615"),
]

to_pil = transforms.ToPILImage()

# Plot and predict
for left_emoji, right_emoji in image_pairs:
    left_img_url = getEmojiUrl(left_emoji)
    right_img_url = getEmojiUrl(right_emoji)

    # unsqueeze(0) adds the batch dimension expected by the model.
    left_img = load_image_from_url(left_img_url).unsqueeze(0).to(device)
    right_img = load_image_from_url(right_img_url).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(left_img, right_img)

    # ToPILImage converts float tensors via mul(255).byte(), so predicted values
    # outside [0, 1] would wrap around; clamp a copy for display only.
    panels = [
        (left_img, "Left Image"),
        (right_img, "Right Image"),
        (output.clamp(0, 1), "Predicted Merged Image"),
    ]

    # Displaying images: inputs on the left, prediction on the right.
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (img, title) in zip(axs, panels):
        ax.imshow(to_pil(img.squeeze(0).cpu()))
        ax.set_title(title)
        ax.axis('off')

    plt.show()