In [14]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageEnhance
from torchvision import transforms
import json
import math

# Load vocabulary from tokenizer.json
def load_vocab():
    with open('tokenizer.json', 'r', encoding='utf-8') as f:
        data = json.load(f)
        return data['vocab']


In [15]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=500):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        from torchvision.models import mobilenet_v3_large

        # Initialize MobileNetV3 without pretrained weights
        mobilenet = mobilenet_v3_large(weights=None)
        
        # Modify first conv layer to accept single channel input
        mobilenet.features[0][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)

        # Remove the classifier
        self.features = mobilenet.features
        self.linear = nn.Linear(960, embed_size)

    def forward(self, images):
        features = self.features(images)
        features = features.permute(0, 2, 3, 1)  # [batch_size, height, width, channels]
        features = features.view(features.size(0), -1, features.size(-1))  # [batch_size, seq_len, channels]
        features = self.linear(features)
        return features

class DecoderTransformer(nn.Module):
    def __init__(self, embed_size, vocab_size, num_layers=6, nhead=8, dim_feedforward=1024, dropout=0.1):
        super(DecoderTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, dropout)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_layers
        )
        self.fc = nn.Linear(embed_size, vocab_size)

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, enc_out, tgt, tgt_mask=None):
        if tgt_mask is None:
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        tgt = self.embedding(tgt)
        tgt = self.pos_encoder(tgt)
        output = self.transformer_decoder(
            tgt.permute(1, 0, 2),
            enc_out.permute(1, 0, 2),
            tgt_mask=tgt_mask
        )
        output = output.permute(1, 0, 2)
        output = self.fc(output)
        return output

class Im2LatexModel(nn.Module):
    def __init__(self, embed_size, vocab_size, **kwargs):
        super(Im2LatexModel, self).__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderTransformer(embed_size, vocab_size, **kwargs)

    def forward(self, images, formulas, formula_mask=None):
        features = self.encoder(images)
        outputs = self.decoder(features, formulas, formula_mask)
        return outputs

    def generate(self, image, start_token, end_token, max_len=200, beam_size=6):
        with torch.no_grad():
            features = self.encoder(image.unsqueeze(0))
            # Initialize beam search
            beams = [(torch.tensor([[start_token]], device=image.device), 0.0)]
            completed_beams = []

            for _ in range(max_len):
                candidates = []

                for seq, score in beams:
                    if seq[0, -1].item() == end_token:
                        completed_beams.append((seq, score))
                        continue

                    # Get predictions for next token
                    out = self.decoder(features, seq)
                    logits = out[:, -1, :]
                    probs = F.log_softmax(logits, dim=-1)

                    # Get top-k candidates
                    values, indices = probs[0].topk(beam_size)
                    for value, idx in zip(values, indices):
                        new_seq = torch.cat([seq, idx.unsqueeze(0).unsqueeze(0)], dim=1)
                        new_score = score + value.item()
                        candidates.append((new_seq, new_score))

                # Select top beam_size candidates
                candidates = sorted(candidates, key=lambda x: x[1], reverse=True)
                beams = candidates[:beam_size]

                # Early stopping if all beams are completed
                if len(completed_beams) >= beam_size:
                    break

            # Add incomplete beams to completed list
            completed_beams.extend(beams)

            # Return sequence with highest score
            best_seq = max(completed_beams, key=lambda x: x[1])[0]

            # Remove both start and end tokens
            final_seq = []
            for token in best_seq.squeeze(0)[1:].tolist():  # Skip start token
                if token == end_token:  # Stop at end token
                    break
                final_seq.append(token)

            return final_seq

In [16]:
vocab = load_vocab()
reverse_vocab = {str(idx): word for word, idx in vocab.items()}
    
    # Initialize model
model = Im2LatexModel(
        embed_size=256,
        vocab_size=len(vocab),
        num_layers=6,
        nhead=8,
        dim_feedforward=1024,
        dropout=0.1
    )
    
    # Load trained weights
checkpoint_path = r"C:\Users\huyho\OneDrive\Máy tính\im2latex_llm\best_model.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
model.eval()

  model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))


Im2LatexModel(
  (encoder): EncoderCNN(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): Hardswish()
      )
      (1): InvertedResidual(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
            (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (1): Conv2dNormActivation(
            (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          )
        )
      )
      (2): InvertedResidual(
        (block): Sequential(
          (0): Conv2dNormActivati

In [17]:
import torch
import torch.nn as nn
from typing import List, Tuple
import torch.utils.mobile_optimizer as mobile_optimizer
class TracedEncoder(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.encoder.eval()

    def forward(self, image):
        return self.encoder(image.unsqueeze(0))

class TracedDecoder(nn.Module):
    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        self.decoder.eval()

    def forward(self, features, tokens):
        # Ensure tokens are long type
        tokens = tokens.long()
        return self.decoder(features, tokens)

def convert_model():
    """Convert encoder and decoder separately"""
    # Load original model
    original_model = Im2LatexModel(
        embed_size=256,
        vocab_size=len(vocab),
        num_layers=6,
        nhead=8,
        dim_feedforward=1024,
        dropout=0.1
    )
    
    # Load weights
    checkpoint_path = "best_model.pth"
    original_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    original_model.eval()
    
    # Trace encoder
    traced_encoder = TracedEncoder(original_model.encoder)
    example_image = torch.randn(1, 150, 700)
    traced_encoder = torch.jit.trace(traced_encoder, example_image)
    # traced_encoder = mobile_optimizer.optimize_for_mobile(traced_encoder)

    traced_decoder = TracedDecoder(original_model.decoder)
    example_features = traced_encoder(example_image)
    example_tokens = torch.zeros((1, 1), dtype=torch.long)
    traced_decoder = torch.jit.trace(traced_decoder, (example_features, example_tokens))
    # traced_decoder = mobile_optimizer.optimize_for_mobile(traced_decoder)
    # Save models
    # traced_encoder._save_for_lite_interpreter("encoder_traced.ptl")
    # traced_decoder._save_for_lite_interpreter("decoder_traced.ptl")
    traced_encoder.save("encoder_traced.ptl")
    traced_decoder.save("decoder_traced.ptl")
    
    print("Models converted and saved successfully!")
    
    return traced_encoder, traced_decoder

In [18]:
encoder, decoder = convert_model()

  original_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
  if a.grad is not None:


Models converted and saved successfully!


Test model convert


In [19]:
def preprocess_image(image_path):
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((150, 700)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: torch.where(x > 0.5, 1.0, 0.0)),
    ])
    
    image = Image.open(image_path).convert('L')
    enhancer = ImageEnhance.Contrast(image)
    image = enhancer.enhance(2.0)
    image = transform(image)
    
    if torch.mean(image) > 0.5:
        image = 1 - image
        
    return image
def decode_prediction(tokens, reverse_vocab):
    words = []
    for token in tokens:
        word = reverse_vocab.get(str(token))
        if word not in ['<PAD>', '<START>', '<END>', '<UNK>']:
            words.append(word)
    return ' '.join(words)

In [20]:
import json
def load_vocab():
    with open('tokenizer.json', 'r', encoding='utf-8') as f:
        data = json.load(f)
        return data['vocab']
vocab = load_vocab()
reverse_vocab = {str(idx): word for word, idx in vocab.items()}

In [21]:
def inference_with_traced_models(encoder, decoder, image, start_token, end_token, max_len=200, beam_size=5):
    """
    Perform inference using beam search with the traced encoder and decoder
    
    Args:
        encoder: Traced encoder model
        decoder: Traced decoder model
        image: Input image tensor
        start_token: Token ID for sequence start
        end_token: Token ID for sequence end
        max_len: Maximum sequence length
        beam_size: Size of beam for search
        
    Returns:
        List of tokens representing the best sequence
    """
    with torch.no_grad():
        # Get image features
        features = encoder(image)
        
        # Initialize beams: (sequence, score)
        beams = [(torch.tensor([[start_token]], dtype=torch.long, device=image.device), 0.0)]
        completed_beams = []
        
        # Beam search
        for _ in range(max_len):
            candidates = []
            
            # Expand each beam
            for seq, score in beams:
                # If sequence completed, add to completed beams
                if seq[0, -1].item() == end_token:
                    completed_beams.append((seq, score))
                    continue
                    
                # Get predictions
                out = decoder(features, seq)
                logits = out[:, -1, :]
                probs = torch.nn.functional.log_softmax(logits, dim=-1)
                
                # Get top-k candidates
                values, indices = probs[0].topk(beam_size)
                for value, idx in zip(values, indices):
                    new_seq = torch.cat([seq, torch.tensor([[idx.item()]], dtype=torch.long, device=seq.device)], dim=1)
                    new_score = score + value.item()
                    candidates.append((new_seq, new_score))
            
            # Sort and select top-k candidates
            candidates = sorted(candidates, key=lambda x: x[1], reverse=True)
            beams = candidates[:beam_size]
            
            # Early stopping if enough complete sequences
            if len(completed_beams) >= beam_size:
                break
        
        # Add incomplete sequences to completed beams
        completed_beams.extend(beams)
        
        # Select best sequence
        best_seq = max(completed_beams, key=lambda x: x[1])[0]
        
        # Convert to list of tokens (excluding start token)
        return [token.item() for token in best_seq.squeeze(0)[1:] if token.item() != end_token]

In [22]:
encoder = torch.jit.load("encoder_traced.ptl")
decoder = torch.jit.load("decoder_traced.ptl")
    

In [23]:
import torch
import torch.nn as nn
from typing import List, Tuple
from torchvision import transforms
from PIL import Image, ImageEnhance
def test_converted_models(test_image_path: str):
    """Test the converted models with a sample image"""
    encoder = torch.jit.load("encoder_traced.ptl")
    decoder = torch.jit.load("decoder_traced.ptl")
    
    image = preprocess_image(test_image_path)
    
    tokens = inference_with_traced_models(
        encoder, 
        decoder,
        image,
        start_token=vocab['<START>'],
        end_token=vocab['<END>'],
        beam_size=5 
    )
    
    latex = decode_prediction(tokens, reverse_vocab)
    print(f"Predicted LaTeX:\n{latex}")

if __name__ == "__main__":
    test_converted_models("image_test/image copy.png")

Predicted LaTeX:
\displaystyle x = \frac { - b \pm \sqrt { b ^ { 2 } - 4 a c } } { 2 a }
