<a href="https://colab.research.google.com/github/dyc2424748461/TPGSR/blob/main/TPGSR_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPGSR: Text Prior Guided Scene Text Image Super-Resolution

This notebook demonstrates how to run TPGSR (Text Prior Guided Scene Text Image Super-Resolution) in Google Colab.

Paper: [Text Prior Guided Scene Text Image Super-resolution](https://arxiv.org/abs/2106.15368)

## 1. Environment Setup

In [None]:
# Report whether PyTorch can see a CUDA device and, if so, which one.
import torch

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Device details are only queryable when a CUDA device is present.
    for detail in (
        f"GPU: {torch.cuda.get_device_name(0)}",
        f"CUDA version: {torch.version.cuda}",
    ):
        print(detail)

In [None]:
# Clone the repository
!git clone https://github.com/dyc2424748461/TPGSR.git
%cd TPGSR

In [None]:
# Install dependencies
!pip install torch==1.2.0 torchvision==0.4.0
!pip install numpy==1.18.0
!pip install Pillow==6.2.2
!pip install lmdb easydict pyfasttext editdistance tensorboardX
!pip install pyyaml scipy matplotlib tqdm opencv-python
!pip install IPython

## 2. Download Dataset and Pretrained Models

In [None]:
# Download and extract the TextZoom dataset into ./data.
#
# Safe to re-run: an archive that is already on disk is not fetched again, and
# extraction overwrites existing files instead of prompting.
import gdown
import os
import zipfile

TEXTZOOM_URL = 'https://drive.google.com/uc?id=1WKVhB2qFjqQUqy8KVqtgEQZCsn2hZ8kV'
TEXTZOOM_ARCHIVE = 'data/TextZoom.zip'

# Create data directory
os.makedirs('data', exist_ok=True)

# Download TextZoom dataset from Google Drive
if os.path.exists(TEXTZOOM_ARCHIVE):
    print(f"Found existing {TEXTZOOM_ARCHIVE}, skipping download.")
else:
    print("Downloading TextZoom dataset...")
    downloaded_path = gdown.download(TEXTZOOM_URL, TEXTZOOM_ARCHIVE, quiet=False)
    # Fail loudly instead of printing a success message after a failed fetch.
    if downloaded_path is None:
        raise RuntimeError(
            f"Could not download the TextZoom dataset from {TEXTZOOM_URL}. "
            "Check the Google Drive link/quota and retry."
        )

# Extract dataset (raises zipfile.BadZipFile if the archive is corrupt)
with zipfile.ZipFile(TEXTZOOM_ARCHIVE) as archive:
    archive.extractall('data')
print("TextZoom dataset downloaded and extracted!")

In [None]:
# Download the pretrained recognizer checkpoints (ASTER, MORAN, CRNN).
#
# One loop replaces three copy-pasted download stanzas. Existing files are
# kept, and a failed download raises instead of being reported as a success.
import gdown
import os

# Create pretrained directory
os.makedirs('pretrained', exist_ok=True)

# model name -> (Google Drive file id, destination path)
PRETRAINED_MODELS = {
    'ASTER': ('1sOqiX9cqOgXV0qbMHTwl5eSV_5_d1gwc', 'pretrained/aster.pth.tar'),
    'MORAN': ('1YLDHhtc5EyRNyhvNQS6ywC9htkdT4c7q', 'pretrained/moran.pth'),
    'CRNN': ('1ooaHefQp0wDATLvOZlsXyLCjjWiHSHKX', 'pretrained/crnn.pth'),
}

for model_name, (file_id, target_path) in PRETRAINED_MODELS.items():
    if os.path.exists(target_path):
        print(f"{model_name} model already present at {target_path}, skipping download.")
        continue

    print(f"Downloading {model_name} model...")
    model_url = f'https://drive.google.com/uc?id={file_id}'
    if gdown.download(model_url, target_path, quiet=False) is None:
        raise RuntimeError(f"Could not download the {model_name} model from {model_url}.")

print("All pretrained models downloaded!")

## 3. Configuration Setup

In [None]:
# Rewrite config/super_resolution.yaml so that its data/model paths and its
# device settings match the current environment.
import os

import torch  # imported here so the cell does not depend on an earlier cell
import yaml
from easydict import EasyDict

CONFIG_PATH = 'config/super_resolution.yaml'
# The notebook runs from the repository root (see the clone cell), which is
# /content/TPGSR on Colab. Deriving the root keeps the cell usable elsewhere.
REPO_ROOT = os.getcwd()


def repo_path(*parts):
    """Return an absolute path to `parts` inside the repository checkout."""
    return os.path.join(REPO_ROOT, *parts)


# Load configuration
# NOTE: yaml.Loader can construct arbitrary Python objects; it is acceptable
# here only because the file comes from the repository cloned above.
with open(CONFIG_PATH, 'r') as f:
    config = yaml.load(f, Loader=yaml.Loader)

train_cfg = config['TRAIN']
val_cfg = train_cfg['VAL']

# Update dataset paths
train_cfg['train_data_dir'] = [
    repo_path('data', 'TextZoom', split) for split in ('train1', 'train2')
]
val_cfg['val_data_dir'] = [
    repo_path('data', 'TextZoom', 'test', level) for level in ('easy', 'medium', 'hard')
]

# Update pretrained recognizer paths
val_cfg['rec_pretrained'] = repo_path('pretrained', 'aster.pth.tar')
val_cfg['moran_pretrained'] = repo_path('pretrained', 'moran.pth')
val_cfg['crnn_pretrained'] = repo_path('pretrained', 'crnn.pth')

# Adjust for GPU/CPU
use_cuda = torch.cuda.is_available()
train_cfg['cuda'] = use_cuda
train_cfg['batch_size'] = 16 if use_cuda else 4  # Adjust based on GPU memory

train_cfg['workers'] = 2
train_cfg['epochs'] = 100  # Reduce for demo

# Save updated configuration
with open(CONFIG_PATH, 'w') as f:
    yaml.dump(config, f, default_flow_style=False)

print("Configuration updated for Colab environment!")
print(f"CUDA enabled: {config['TRAIN']['cuda']}")
print(f"Batch size: {config['TRAIN']['batch_size']}")

## 4. Code Modifications for Compatibility

In [None]:
# Disable ptflops usage in the interface modules by commenting out every line
# that mentions `ptflops` or `get_model_complexity_info`.
import re

# A single pattern covers both keywords, so a line such as
# `from ptflops import get_model_complexity_info` is commented once, not twice.
# The negative lookahead skips lines that are already comments, which makes
# re-running this cell a no-op instead of stacking more '# ' prefixes.
PTFLOPS_LINE = re.compile(
    r'^(?![ \t]*#)(.*(?:ptflops|get_model_complexity_info).*)$',
    flags=re.MULTILINE,
)


def comment_out_ptflops(source):
    """Return `source` with every not-yet-commented ptflops line prefixed by '# '.

    NOTE(review): only the matching physical line is commented out. If a
    matching statement spans several lines in the target file, its
    continuation lines are left in place - verify the patched files import.
    """
    return PTFLOPS_LINE.sub(r'# \1', source)


for module_path in ('interfaces/base.py', 'interfaces/super_resolution.py'):
    with open(module_path, 'r') as f:
        original_source = f.read()

    patched_source = comment_out_ptflops(original_source)

    # Only touch the file when something actually changed.
    if patched_source != original_source:
        with open(module_path, 'w') as f:
            f.write(patched_source)

print("Code modifications completed!")

In [None]:
# Patch interfaces/base.py so checkpoints load on whatever device is available
# and MORAN is constructed for that device.
import re

import torch  # imported here so the cell does not depend on an earlier cell

BASE_MODULE = 'interfaces/base.py'
DEVICE_EXPR = "torch.device('cuda' if torch.cuda.is_available() else 'cpu')"


def _find_call_end(source, open_idx):
    """Return the index of the ')' matching the '(' at `open_idx`, or -1.

    Parentheses inside string literals are counted like any others.
    """
    depth = 0
    for idx in range(open_idx, len(source)):
        char = source[idx]
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                return idx
    return -1


def add_map_location(source, device_expr=DEVICE_EXPR):
    """Append `map_location=<device_expr>` to every `torch.load(...)` call.

    Unlike a `[^)]+` regex, the closing parenthesis is found by balancing, so
    `torch.load(os.path.join(a, b))` gets the keyword on `torch.load` rather
    than inside `os.path.join`. Calls that already pass `map_location` (for
    example after a previous run of this cell) and calls without arguments
    are left untouched, which makes the patch idempotent.
    """
    pieces = []
    cursor = 0
    for match in re.finditer(r'torch\.load\(', source):
        if match.start() < cursor:
            # Nested inside the argument list of a call handled already.
            continue
        close_idx = _find_call_end(source, match.end() - 1)
        if close_idx < 0:
            # Unbalanced parentheses: leave the remainder as it is.
            break
        call_args = source[match.end():close_idx]
        pieces.append(source[cursor:close_idx])
        cursor = close_idx
        if not call_args.strip() or 'map_location' in call_args:
            continue
        # Avoid emitting ',,' when the call already ends with a trailing comma.
        separator = ' ' if call_args.rstrip().endswith(',') else ', '
        pieces.append(f'{separator}map_location={device_expr}')
    pieces.append(source[cursor:])
    return ''.join(pieces)


with open(BASE_MODULE, 'r') as f:
    content = f.read()

# Add map_location for torch.load calls
content = add_map_location(content)

# Fix MORAN initialization for CPU/GPU
MORAN_CPU_ARGS = "inputDataType='torch.FloatTensor', CUDA=False"
MORAN_GPU_ARGS = "inputDataType='torch.cuda.FloatTensor', CUDA=True"
if torch.cuda.is_available():
    content = content.replace(MORAN_CPU_ARGS, MORAN_GPU_ARGS)
else:
    content = content.replace(MORAN_GPU_ARGS, MORAN_CPU_ARGS)

with open(BASE_MODULE, 'w') as f:
    f.write(content)

print("Device compatibility fixes applied!")

## 5. Demo Run

In [None]:
# Create the demo directories and a synthetic 128x32 test image with the word
# "HELLO" on it.
import os
from PIL import Image, ImageDraw, ImageFont
import numpy as np

os.makedirs('demo', exist_ok=True)
os.makedirs('demo_results', exist_ok=True)

# Create a simple test image with text
img = Image.new('RGB', (128, 32), color='white')
draw = ImageDraw.Draw(img)

try:
    # Try to use a system font
    font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', 20)
except (OSError, ImportError):
    # Font file missing/unreadable, or Pillow built without FreeType support:
    # fall back to the built-in bitmap font. Other errors are not swallowed.
    font = ImageFont.load_default()

draw.text((10, 5), "HELLO", fill='black', font=font)
img.save('demo/test.png')

print("Demo test image created!")
# img.show() launches an external viewer, which displays nothing in a hosted
# notebook. Leaving the image as the last expression renders it inline.
img

In [None]:
# Create and run demo script
demo_script = '''
import os
import torch
from PIL import Image
import torchvision.transforms as transforms
import yaml
from easydict import EasyDict

# Load configuration
config_path = os.path.join('config', 'super_resolution.yaml')
config = yaml.load(open(config_path, 'r'), Loader=yaml.Loader)
config = EasyDict(config)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create a simple transform
transform = transforms.Compose([
    transforms.Resize((32, 128)),
    transforms.ToTensor(),
])

# Load a test image
img_path = 'demo/test.png'
img = Image.open(img_path).convert('RGB')
img_tensor = transform(img).unsqueeze(0)

print(f"Loaded image from {img_path}")
print(f"Image tensor shape: {img_tensor.shape}")

# Save the transformed image
transformed_img = transforms.ToPILImage()(img_tensor.squeeze(0))
transformed_img.save('demo_results/input.png')

print("Demo preprocessing completed successfully!")
print("Input image saved to demo_results/input.png")
'''

with open('run_demo.py', 'w') as f:
    f.write(demo_script)

# Run the demo
!python run_demo.py

## 6. Training (Optional)

In [None]:
# Training is opt-in because it runs for a very long time.
# Uncomment the following line to start training
# !python main.py --batch_size=16 --STN --mask --gradient --sr_share --use_distill --without_colorjitter --test_model=TSRN

# Emit the two reminder lines, one per print call.
for reminder in (
    "Training command prepared. Uncomment the line above to start training.",
    "Note: Training will take several hours/days depending on your hardware.",
):
    print(reminder)

## 7. Inference with Pretrained Model

In [None]:
# No pretrained TPGSR checkpoint is downloaded by this notebook; this cell only
# lists the options for obtaining one.
# Note: You may need to train your own model or find a pretrained one
INFERENCE_NOTES = (
    "To run inference, you need a trained TPGSR model.",
    "You can either:",
    "1. Train your own model using the training section above",
    "2. Download a pretrained model if available",
    "3. Use the demo script above for basic functionality testing",
)

for note in INFERENCE_NOTES:
    print(note)

## 8. View Results

In [None]:
# Show the preprocessed input image produced by the demo script, if present.
import matplotlib.pyplot as plt
from PIL import Image
import os

result_path = 'demo_results/input.png'

# Nothing is drawn when the demo has not produced its output yet.
if os.path.exists(result_path):
    img = Image.open(result_path)
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.imshow(img)
    ax.set_title('Input Image')
    ax.axis('off')
    plt.show()

print("Demo completed! Check the demo_results folder for output images.")

## 9. Cleanup (Optional)

In [None]:
# Optional cleanup of the large dataset files. The commands are left disabled;
# uncomment the following lines if you want to clean up.

# !rm -rf data/TextZoom.zip
# !rm -rf data/TextZoom
# print("Cleanup completed!")

cleanup_hint = "Cleanup commands prepared. Uncomment to clean up large files."
print(cleanup_hint)