In [23]:
import torch
from torchvision import transforms
from PIL import Image
import cv2
import os
import numpy as np
from src.interfaces.super_resolution import TextSR  # Import the actual model architecture

import yaml
from easydict import EasyDict as edict

# Load config and args from YAML file
with open('C:/Users/vaibh/TextZoom/src/config/super_resolution.yaml', 'r') as file:
    config = edict(yaml.safe_load(file))

# Assume args are loaded similarly, or manually define them as needed
args = edict({
    'syn': False,
    'rec': 'aster',
    'demo_dir': './demo',
    'mixed': False,
    'resume': None, 
    'batch_size': 32,
    'test_data_dir': "C:/Users/vaibh/TextZoom/resized_images",
    'vis_dir': "C:/Users/vaibh/TextZoom/seg_image_preprocess",
    'mask': False 
})

# Create an instance of the model with config and args
model = TextSR(config=config, args=args)
# Load the state_dict into the model
state_dict = torch.load('checkpoints/model_best.pth', map_location=torch.device('cpu'))

# If the loaded object is a state_dict, load it into the model
model.model.load_state_dict(state_dict)

# Set model to evaluation mode
model.eval()


# Image preprocessing function
def preprocess_image(image_path):
    img = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((32, 128)),
        transforms.ToTensor()
    ])
    img = transform(img).unsqueeze(0)
    return img

# Image enhancement function
def enhance_image(model, image_path):
    img_tensor = preprocess_image(image_path)
    
    # Forward pass through the model
    with torch.no_grad():
        enhanced_img = model(img_tensor)
    
    # Post-process the output
    enhanced_img_np = enhanced_img.squeeze().cpu().numpy()
    enhanced_img_np = np.transpose(enhanced_img_np, (1, 2, 0))  # Convert back to HWC
    enhanced_img_np = (enhanced_img_np * 255).astype(np.uint8)  # Convert to uint8
    
    return enhanced_img_np

# Function to process all images in a folder and save them in another folder
def process_folder(input_folder, output_folder):
    # Create output folder if it doesn't exist
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    
    # Loop through all images in the input folder
    for filename in os.listdir(input_folder):
        if filename.endswith(('.jpg', '.png', '.jpeg')):  # Filter image files
            image_path = os.path.join(input_folder, filename)
            enhanced_image = enhance_image(model, image_path)
            
            # Save the enhanced image in the output folder
            output_path = os.path.join(output_folder, filename)
            cv2.imwrite(output_path, enhanced_image)
            print(f"Processed and saved: {output_path}")

# Define input and output folder paths
input_folder = 'resized_images'
output_folder = 'textzoom_transformation'

# Process images from input folder and save them to output folder
process_folder(input_folder, output_folder)


AttributeError: 'TextSR' object has no attribute 'model'