In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torchvision import transforms
from collections import Counter
from torch.utils.data import Dataset, DataLoader
# ========================== #
# 1) Read the dataset CSVs
# ========================== #
# Root of the Kaggle dataset; replace with the correct path for your setup.
data_dir = "/kaggle/input/im2latex-premium"
# Directory that holds the formula images referenced by the CSV files.
image_dir = os.path.join(data_dir, "root/images")

train_path = os.path.join(data_dir, "total_train.csv")
# NOTE(review): "val2F" looks like a URL-encoded "val/" ("%2F") that ended up
# in the file name - confirm this is the real name inside the dataset.
validate_path = os.path.join(data_dir, "val2Ftotal_val.csv")
# A separate test split and a normalized formula list are not loaded here:
# test_path = os.path.join(data_dir, "im2latex_test.csv")
# formulas_df = pd.read_csv(formulas_path)
# test_df = pd.read_csv(test_path)

# Load the training split first, then the validation split.
train_df, validate_df = (
    pd.read_csv(csv_path) for csv_path in (train_path, validate_path)
)

In [2]:
import os
import numpy as np
import pandas as pd
from PIL import Image, ImageEnhance  # Add ImageEnhance here
import torch
from torchvision import transforms
from collections import Counter
from torch.utils.data import Dataset, DataLoader
class Im2LatexDataset(Dataset):
    """Dataset of (formula image, LaTeX string) pairs.

    Each item is read from one row of ``df``: the ``image_filename`` column
    names a file under ``img_dir`` and the ``latex`` column holds the
    whitespace-separated formula tokens.

    Args:
        df: DataFrame with ``image_filename`` and ``latex`` columns.
        img_dir: Directory containing the image files.
        max_len: Stored on the instance for callers; not used by this class
            (token padding/truncation is done by the tokenizer).
        transform: Optional callable applied to the PIL image. When omitted,
            the default pipeline converts to single-channel grayscale,
            resizes to 150x700 (height x width per torchvision's ``Resize``
            convention), converts to a tensor and binarizes at 0.5.

    Returns (per item):
        ``(image, formula)`` where ``image`` is the transformed tensor and
        ``formula`` is the LaTeX string wrapped as ``'<START> ... <END>'``.
    """

    def __init__(self, df, img_dir, max_len=200, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.max_len = max_len
        self.transform = transform if transform else transforms.Compose([
            transforms.Grayscale(num_output_channels=1),  # Single-channel input for the encoder
            transforms.Resize((150, 700)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: torch.where(x > 0.5, 1.0, 0.0)),  # Binarize to {0, 1}
        ])

    def __getitem__(self, idx):
        # Fetch the row once instead of re-indexing the frame per column.
        row = self.df.iloc[idx]
        img_name = row['image_filename']
        formula = '<START> ' + row['latex'] + ' <END>'

        # Load the image; the context manager releases the file handle as soon
        # as the grayscale copy has been made.
        img_path = os.path.join(self.img_dir, img_name)
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('L')

        # Double the contrast so the 0.5 binarization threshold separates
        # strokes from background more reliably.
        image = ImageEnhance.Contrast(image).enhance(2.0)

        # Apply transforms
        image = self.transform(image)

        # Normalize polarity: when most pixels are bright (mean > 0.5) the
        # image is inverted, so the majority/background pixels end up at 0
        # and the strokes at 1.
        if torch.mean(image) > 0.5:
            image = 1 - image

        return image, formula

    def __len__(self):
        return len(self.df)

In [3]:
from collections import Counter

class Tokenizer:
    """Whitespace tokenizer with a frequency-filtered vocabulary.

    Ids 0-3 are reserved for ``<PAD>``, ``<START>``, ``<END>`` and ``<UNK>``.
    Every other token that occurs at least ``min_freq`` times in ``formulas``
    gets the next free id, in first-seen order.

    Args:
        formulas: Iterable of formula strings (tokens separated by whitespace).
        min_freq: Minimum number of occurrences for a token to enter the
            vocabulary; rarer tokens are encoded as ``<UNK>``.
    """

    def __init__(self, formulas, min_freq=5):
        self.pad_token = '<PAD>'
        self.start_token = '<START>'
        self.end_token = '<END>'
        self.unk_token = '<UNK>'

        # Count token occurrences over the whole corpus.
        word_counts = Counter()
        for formula in formulas:
            word_counts.update(formula.split())

        self.vocab = {self.pad_token: 0, self.start_token: 1,
                      self.end_token: 2, self.unk_token: 3}

        # Keep tokens with frequency >= min_freq. Tokens that collide with a
        # special token are skipped so the reserved ids (notably the padding
        # id used as the loss ignore_index) are never overwritten.
        next_id = len(self.vocab)
        for word, count in word_counts.items():
            if count >= min_freq and word not in self.vocab:
                self.vocab[word] = next_id
                next_id += 1

        self.reverse_vocab = {token_id: word for word, token_id in self.vocab.items()}

    def encode(self, formula, max_len=200):
        """Convert a formula string to a fixed-length ``torch.LongTensor``.

        Unknown tokens map to ``<UNK>``. Shorter sequences are right-padded
        with ``<PAD>``; longer ones are truncated to ``max_len``.

        NOTE(review): truncation drops whatever sits past ``max_len``,
        including a trailing ``<END>`` if the caller added one.
        """
        unk_id = self.vocab[self.unk_token]
        ids = [self.vocab.get(word, unk_id) for word in formula.split()]

        # Pad or truncate to exactly max_len entries.
        if len(ids) < max_len:
            ids = ids + [self.vocab[self.pad_token]] * (max_len - len(ids))
        else:
            ids = ids[:max_len]

        return torch.LongTensor(ids)

    def decode(self, ids):
        """Convert token ids back to a space-joined string, skipping padding.

        Args:
            ids: Iterable of ids; elements may be plain ints or 0-dim tensors
                (so both a 1-D tensor and a ``list[int]`` are accepted).

        Raises:
            KeyError: If an id is not part of the vocabulary.
        """
        pad_id = self.vocab[self.pad_token]
        words = []
        for token_id in ids:
            token_id = int(token_id)  # works for ints and 0-dim tensors alike
            if token_id != pad_id:
                words.append(self.reverse_vocab[token_id])
        return ' '.join(words)

In [4]:
import torch
import torch.nn as nn
import torchvision.models as models
import math

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a batch-first sequence.

    A table of shape ``(1, max_len, d_model)`` is precomputed and stored as a
    non-trainable buffer ``pe``; ``forward`` adds its first ``x.size(1)`` rows
    to the input and applies dropout.

    Args:
        d_model: Size of the last (feature) dimension. Odd values are
            supported; the final sine column then has no cosine partner.
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions precomputed; inputs longer than this
            cannot be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=500):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # There are only d_model // 2 odd columns; slicing keeps the shapes
        # aligned when d_model is odd and is a no-op when it is even.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])`` for ``x`` indexed as (batch, seq_len, d_model)."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
class EncoderCNN(nn.Module):
    """Image encoder: MobileNetV3-Large feature extractor plus a linear projection.

    The backbone is created without pretrained weights and its first
    convolution is replaced so it accepts single-channel input. The spatial
    feature map is flattened into a sequence and every position is projected
    to ``embed_size``.

    Args:
        embed_size: Output feature size per sequence position.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        from torchvision.models import mobilenet_v3_large

        backbone = mobilenet_v3_large(weights=None)

        # Modify first conv layer to accept single channel input
        backbone.features[0][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)

        # Keep only the convolutional trunk; the classifier head is dropped.
        self.features = backbone.features
        # 960 is the channel count the projection expects from the trunk.
        self.linear = nn.Linear(960, embed_size)

    def forward(self, images):
        """Encode ``images`` into a feature sequence of shape [batch, seq_len, embed_size]."""
        features = self.features(images)
        features = features.permute(0, 2, 3, 1)  # [batch_size, height, width, channels]
        # reshape (not view): the permuted tensor is non-contiguous, and view
        # only succeeds when the strides happen to be compatible.
        features = features.reshape(features.size(0), -1, features.size(-1))  # [batch_size, seq_len, channels]
        features = self.linear(features)
        return features
class DecoderTransformer(nn.Module):
    """Transformer decoder that maps encoder features and token ids to vocabulary logits.

    Args:
        embed_size: Model width shared by token embeddings and encoder features.
        vocab_size: Number of token ids (size of the output distribution).
        num_layers: Number of stacked decoder layers.
        nhead: Attention heads per layer.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
        dropout: Dropout probability used by the positional encoding and the layers.
    """

    def __init__(self, embed_size, vocab_size, num_layers=6,
                 nhead=8, dim_feedforward=1024, dropout=0.1):
        super().__init__()

        # Token ids -> vectors, then position information is added.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, dropout)

        # A single layer template is replicated num_layers times by the stack.
        layer_template = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer_decoder = nn.TransformerDecoder(layer_template, num_layers=num_layers)

        # Final projection from model width to vocabulary logits.
        self.fc = nn.Linear(embed_size, vocab_size)

    def generate_square_subsequent_mask(self, sz):
        """Return a float (sz, sz) causal mask: 0.0 on/below the diagonal, -inf above it."""
        blocked = torch.full((sz, sz), float('-inf'))
        return torch.triu(blocked, diagonal=1)

    def forward(self, enc_out, tgt, tgt_mask=None):
        """Decode ``tgt`` token ids against ``enc_out`` and return per-position logits.

        ``enc_out`` and ``tgt`` are batch-first; a causal mask is built from
        the target length when ``tgt_mask`` is not supplied.
        """
        if tgt_mask is None:
            causal = self.generate_square_subsequent_mask(tgt.size(1))
            tgt_mask = causal.to(tgt.device)

        embedded = self.pos_encoder(self.embedding(tgt))

        # The decoder stack is sequence-first, so swap batch and time going in
        # and swap them back coming out.
        decoded = self.transformer_decoder(
            embedded.transpose(0, 1),
            enc_out.transpose(0, 1),
            tgt_mask=tgt_mask,
        )

        return self.fc(decoded.transpose(0, 1))

class Im2LatexModel(nn.Module):
    """Image-to-LaTeX model: CNN encoder followed by a Transformer decoder.

    Args:
        embed_size: Feature width shared by encoder output and decoder.
        vocab_size: Number of token ids the decoder predicts over.
        **kwargs: Forwarded unchanged to ``DecoderTransformer``.
    """

    def __init__(self, embed_size, vocab_size, **kwargs):
        super(Im2LatexModel, self).__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderTransformer(embed_size, vocab_size, **kwargs)

    def forward(self, images, formulas, formula_mask=None):
        """Teacher-forced pass: return decoder logits for ``formulas`` given ``images``."""
        features = self.encoder(images)
        outputs = self.decoder(features, formulas, formula_mask)
        return outputs

    def generate(self, image, start_token, end_token, max_len=200, beam_size=5):
        """Decode a single image with beam search.

        Args:
            image: One image tensor without a batch dimension (a batch axis
                is added here).
            start_token: Id that seeds every hypothesis.
            end_token: Id that marks a hypothesis as finished.
            max_len: Maximum number of decoding steps.
            beam_size: Number of hypotheses kept per step (capped at the
                vocabulary size when expanding a hypothesis).

        Returns:
            ``list[int]`` of predicted token ids, without the start token and
            cut at the first end token.

        Note:
            Gradients are disabled here, but the train/eval mode is left
            untouched; call ``model.eval()`` first to switch off dropout.
            Scores are raw sums of log-probabilities (no length normalization).
        """
        with torch.no_grad():
            # Encode image
            features = self.encoder(image.unsqueeze(0))

            # Each beam is (token ids of shape [1, t], cumulative log-prob).
            beams = [(torch.tensor([[start_token]], device=image.device), 0.0)]
            completed_beams = []

            for _ in range(max_len):
                candidates = []

                for seq, score in beams:
                    if seq[0, -1].item() == end_token:
                        completed_beams.append((seq, score))
                        continue

                    # Get predictions for next token
                    out = self.decoder(features, seq)
                    # torch.log_softmax avoids depending on an `F` alias that
                    # this file never imports.
                    log_probs = torch.log_softmax(out[:, -1, :], dim=-1)

                    # topk raises if k exceeds the vocabulary size.
                    k = min(beam_size, log_probs.size(-1))
                    values, indices = log_probs[0].topk(k)
                    for value, idx in zip(values, indices):
                        new_seq = torch.cat([seq, idx.view(1, 1)], dim=1)
                        candidates.append((new_seq, score + value.item()))

                if not candidates:
                    # Every live beam has finished; nothing left to expand.
                    beams = []
                    break

                # Select top beam_size candidates
                candidates.sort(key=lambda cand: cand[1], reverse=True)
                beams = candidates[:beam_size]

                # Early stopping once enough hypotheses have finished
                if len(completed_beams) >= beam_size:
                    break

            # Unfinished beams still compete for the best score.
            completed_beams.extend(beams)

            # Return sequence with highest score
            best_seq = max(completed_beams, key=lambda x: x[1])[0]

            # Remove both start and end tokens
            final_seq = []
            for token in best_seq.squeeze(0)[1:].tolist():  # Skip start token
                if token == end_token:  # Stop at end token
                    break
                final_seq.append(token)

            return final_seq

In [7]:
import torch.cuda.amp as amp
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
import os
def _batch_loss(model, criterion, tokenizer, images, formulas, device):
    """Teacher-forced loss for one batch: predict token t+1 from tokens up to t."""
    images = images.to(device)
    token_ids = torch.stack([tokenizer.encode(f) for f in formulas]).to(device)

    # Shift by one: the decoder sees everything but the last token and is
    # scored against everything but the first.
    input_seqs = token_ids[:, :-1]
    target_seqs = token_ids[:, 1:]

    outputs = model(images, input_seqs)
    return criterion(
        outputs.reshape(-1, outputs.size(-1)),
        target_seqs.reshape(-1)
    )


def train_model(model, train_loader, val_loader, tokenizer,
               num_epochs=3, device='cuda', lr=0.0008):
    """Train ``model`` with mixed precision and log progress to Weights & Biases.

    Args:
        model: Module called as ``model(images, input_token_ids)`` returning
            logits; may be wrapped (anything exposing ``.module`` is unwrapped
            before saving).
        train_loader: Loader yielding ``(images, formula_strings)`` batches.
        val_loader: Loader of the same form used for validation.
        tokenizer: Object with ``vocab['<PAD>']`` and ``encode(str)``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and batches are moved to.
        lr: Initial Adam learning rate (also recorded in the wandb config).

    Side effects:
        Starts and finishes a wandb run, and writes ``best_model.pth`` under
        the run directory whenever validation loss improves.
    """
    # Initialize wandb first
    wandb.init(
        project="im2latex_best_v1",
        settings=wandb.Settings(init_timeout=1200000),
        config={
            "epochs": num_epochs,
            # The loader's batch size is the global batch: DataParallel splits
            # each batch across GPUs rather than multiplying it.
            "batch_size": train_loader.batch_size,
            # Log the rate the optimizer actually uses.
            "learning_rate": lr,
            # NOTE(review): the encoder in this notebook is MobileNetV3, not
            # EfficientNet; label kept so existing runs stay comparable.
            "architecture": "Efnet-Transformer",
            "dataset_size": len(train_loader.dataset),
            "num_gpus": torch.cuda.device_count()
        }
    )

    try:
        # Create checkpoint directory after wandb is initialized
        checkpoint_dir = os.path.join(wandb.run.dir, "checkpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)

        scaler = amp.GradScaler()
        criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.vocab['<PAD>'])

        optimizer = Adam(model.parameters(), lr=lr)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-5)
        model = model.to(device)

        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            model.train()
            total_loss = 0.0
            # Running sum/count of losses since the last scheduler step.
            window_loss, window_batches = 0.0, 0
            epoch_start_time = time.time()

            print(f"\nEpoch {epoch+1}/{num_epochs}")

            for batch_idx, (images, formulas) in enumerate(train_loader):
                with amp.autocast():
                    loss = _batch_loss(model, criterion, tokenizer, images, formulas, device)

                optimizer.zero_grad()
                scaler.scale(loss).backward()
                # Unscale before clipping so the norm is measured on real gradients.
                scaler.unscale_(optimizer)
                clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()

                batch_loss = loss.item()
                total_loss += batch_loss
                window_loss += batch_loss
                window_batches += 1

                if batch_idx % 100 == 0:
                    print(f"\rBatch [{batch_idx}/{len(train_loader)}] Loss: {batch_loss:.4f}", end="")
                    wandb.log({
                        "train_batch_loss": batch_loss,
                        "learning_rate": optimizer.param_groups[0]['lr'],
                        "epoch": epoch,
                        "batch": batch_idx,
                        "num_gpus": torch.cuda.device_count()
                    })
                    # Same cadence as before (every 100 batches), but driven
                    # by the mean loss since the previous step: a single
                    # batch loss is too noisy for a patience-3 plateau test.
                    scheduler.step(window_loss / window_batches)
                    window_loss, window_batches = 0.0, 0

            # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
            avg_train_loss = total_loss / max(len(train_loader), 1)

            # Validation
            model.eval()
            val_loss = 0.0

            with torch.no_grad():
                with amp.autocast():
                    for images, formulas in val_loader:
                        loss = _batch_loss(model, criterion, tokenizer, images, formulas, device)
                        val_loss += loss.item()

            val_loss /= max(len(val_loader), 1)
            epoch_time = time.time() - epoch_start_time

            print(f"\nTime: {epoch_time:.1f}s | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f}")
            wandb.log({
                "train_loss": avg_train_loss,
                "val_loss": val_loss,
                "learning_rate": optimizer.param_groups[0]['lr'],
                "epoch": epoch
            })

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Use wandb run directory for checkpoint path
                checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
                # Save the unwrapped module so the weights load without the wrapper.
                model_to_save = model.module if hasattr(model, 'module') else model
                torch.save(model_to_save.state_dict(), checkpoint_path)

                # Log best model artifact to wandb
                artifact = wandb.Artifact(
                    name=f"model-checkpoint-epoch{epoch}",
                    type="model",
                    description=f"Model checkpoint from epoch {epoch} with val_loss {val_loss:.4f}"
                )
                # Add file to artifact after ensuring it exists
                if os.path.exists(checkpoint_path):
                    artifact.add_file(checkpoint_path)
                    wandb.log_artifact(artifact)
                    print(f"\nSaved new best model checkpoint with val_loss: {val_loss:.4f}")
                else:
                    print(f"\nWarning: Failed to save checkpoint at {checkpoint_path}")
    finally:
        # Close the run even when training is interrupted or raises.
        wandb.finish()

In [None]:
import wandb
import os
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
import numpy as np

# W&B credentials must come from outside the notebook (environment variable,
# Kaggle Secrets, or a prior `wandb login`); never commit the key itself.
# NOTE(review): a literal API key was previously committed in this cell - treat
# it as leaked and rotate it in the W&B account settings.
if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set; wandb will use stored `wandb login` credentials or prompt for a key.")

if __name__ == "__main__":
    # Initialize dataset and dataloader
    train_dataset = Im2LatexDataset(train_df, image_dir)
    val_dataset = Im2LatexDataset(validate_df, image_dir)

    batch_size = 260
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )

    # Build the vocabulary from the training formulas only.
    tokenizer = Tokenizer(train_df["latex"].values)

    def count_parameters(model):
        """Return the number of trainable parameters in ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Create the base model first
    base_model = Im2LatexModel(
        embed_size=256,
        vocab_size=len(tokenizer.vocab),
        num_layers=6,
        nhead=8,
        dim_feedforward=1024,
        dropout=0.1
    )

    # Count and print parameters before wrapping with DataParallel
    encoder_params = count_parameters(base_model.encoder)
    decoder_params = count_parameters(base_model.decoder)
    total_params = count_parameters(base_model)

    print(f"\nModel Parameter Counts:")
    print("-" * 40)
    print(f"Encoder: {encoder_params:,} parameters")
    print(f"Decoder (Transformer): {decoder_params:,} parameters")
    print(f"Total: {total_params:,} parameters")
    print("-" * 40)

    print(f"\nParameter Distribution:")
    print(f"Encoder: {encoder_params/total_params*100:.1f}%")
    print(f"Decoder: {decoder_params/total_params*100:.1f}%")

    # Now wrap the model with DataParallel
    if torch.cuda.device_count() > 1:
        print(f"\nUsing {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(base_model)
    else:
        model = base_model

    # Train model
    train_model(model, train_loader, val_loader, tokenizer)


Model Parameter Counts:
----------------------------------------
Encoder: 3,217,680 parameters
Decoder (Transformer): 8,044,833 parameters
Total: 11,262,513 parameters
----------------------------------------

Parameter Distribution:
Encoder: 28.6%
Decoder: 71.4%

Using 2 GPUs!


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhuyhoangak4[0m ([33mhoangvbck[0m). Use [1m`wandb login --relogin`[0m to force relogin


  scaler = amp.GradScaler()



Epoch 1/3


  with amp.autocast():


Batch [2000/13222] Loss: 0.4671