In [None]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
from PIL import Image


In [None]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed them as a token sequence.

    A strided convolution projects each ``patch_size x patch_size`` patch to an
    ``emb_dim`` vector, a learnable CLS token is prepended, and a learnable
    positional embedding is added.

    Args:
        img_size: Side length of the square input image. The positional embedding
            is sized for exactly ``(img_size // patch_size) ** 2`` patches.
        patch_size: Side length of each square patch (also the conv stride).
        emb_dim: Dimension of every output token.
        in_channels: Number of input image channels (3 for RGB).

    Input:  ``[B, in_channels, img_size, img_size]``
    Output: ``[B, (img_size // patch_size) ** 2 + 1, emb_dim]`` with the CLS token first.

    Raises:
        ValueError: if the input yields a different number of patches than the
            positional embedding was built for.
    """

    def __init__(self, img_size=224, patch_size=16, emb_dim=128, in_channels=3):
        super().__init__()
        self.img_size = img_size
        self.n_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, self.n_patches + 1, emb_dim))

    def forward(self, x):
        in_shape = tuple(x.shape)
        B = x.size(0)
        x = self.patch_embed(x)  # [B, emb_dim, H', W']
        x = x.flatten(2).transpose(1, 2)  # [B, N, emb_dim]
        if x.size(1) != self.n_patches:
            # The positional embedding has a fixed length; fail with a clear message
            # instead of an opaque broadcasting error on the addition below.
            raise ValueError(
                f"expected {self.n_patches} patches (img_size={self.img_size}), "
                f"got {x.size(1)} from input of shape {in_shape}"
            )
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [B, N + 1, emb_dim]
        x = x + self.pos_embedding
        return x

class TransformerEncoder(nn.Module):
    """A stack of ``depth`` standard Transformer encoder layers over batch-first tokens.

    Thin wrapper around ``nn.TransformerEncoder``; input and output both have
    shape ``[B, N, emb_dim]``.

    Args:
        emb_dim: Token dimension (``d_model``).
        n_heads: Number of attention heads per layer.
        depth: Number of stacked encoder layers.
        dropout: Dropout probability used inside each layer.
    """

    def __init__(self, emb_dim=128, n_heads=4, depth=4, dropout=0.1):
        super().__init__()
        layer_template = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=n_heads,
            dropout=dropout,
            batch_first=True,  # tokens arrive as [B, N, emb_dim]
        )
        self.transformer = nn.TransformerEncoder(layer_template, num_layers=depth)

    def forward(self, x):
        encoded = self.transformer(x)
        return encoded

class ViTEncoder(nn.Module):
    """Vision Transformer mapping an image batch to one embedding per image.

    Pipeline: patch embedding (+ CLS token and positions) -> Transformer encoder
    -> LayerNorm + Linear head applied to the CLS token.

    Args:
        emb_dim: Token and output embedding dimension.
        img_size: Side length of the square input images.
        patch_size: Side length of each square patch.
        n_heads: Attention heads per encoder layer.
        depth: Number of encoder layers.
        dropout: Dropout probability inside the encoder layers.

    Input ``[B, 3, img_size, img_size]`` -> output ``[B, emb_dim]``.
    """

    def __init__(self, emb_dim=128, img_size=224, patch_size=16, n_heads=4, depth=4, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, emb_dim)
        self.encoder = TransformerEncoder(emb_dim=emb_dim, n_heads=n_heads, depth=depth, dropout=dropout)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x):
        tokens = self.patch_embed(x)
        encoded = self.encoder(tokens)
        # Position 0 is the CLS token prepended by the patch embedding.
        return self.mlp_head(encoded[:, 0])

class SiameseViT(nn.Module):
    """Siamese network: embeds two inputs with one shared (weight-tied) ViT encoder.

    ``forward(x1, x2)`` returns the tuple ``(encoder(x1), encoder(x2))``.
    """

    def __init__(self, emb_dim=128):
        super().__init__()
        self.encoder = ViTEncoder(emb_dim=emb_dim)

    def forward(self, x1, x2):
        # x1 is encoded first, then x2, both through the same module.
        return tuple(self.encoder(branch) for branch in (x1, x2))


In [None]:
class ChartMatchDataset(Dataset):
    """Labelled (crop, page) pairs for training the Siamese chart matcher.

    Every image file in ``chart_dir`` is a positive sample (label 1.0) and every
    image file in ``nonchart_dir`` a negative sample (label 0.0). Each crop is
    paired with the same reference page image loaded from ``page_path``.

    NOTE(review): every crop is compared against the whole page rather than the
    page region it came from — confirm this is the intended training signal.

    Args:
        chart_dir: Directory containing chart crops (label 1).
        nonchart_dir: Directory containing non-chart crops (label 0).
        transform: Optional callable applied to both the crop and the page
            (receives a PIL RGB image).
        page_path: Reference page image paired with every crop.
        extensions: Accepted file suffix(es); matched case-insensitively.

    Items are ``(crop_img, page_img, label)`` with ``label`` a float tensor.
    """

    def __init__(self, chart_dir, nonchart_dir, transform=None,
                 page_path="page.png", extensions=('.png', '.jpg', '.jpeg')):
        self.transform = transform
        self.page_path = page_path
        self._page_img = None  # reference page; loaded on first use, then reused
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        self.data = (
            [(path, 1) for path in self._list_images(chart_dir, suffixes)]
            + [(path, 0) for path in self._list_images(nonchart_dir, suffixes)]
        )

    @staticmethod
    def _list_images(directory, suffixes):
        """Return full paths of files in ``directory`` with an accepted suffix, sorted by name.

        Sorting keeps sample indices reproducible across runs and machines,
        since ``os.listdir`` returns entries in arbitrary order.
        """
        return [
            os.path.join(directory, fname)
            for fname in sorted(os.listdir(directory))
            if fname.lower().endswith(suffixes)
        ]

    def _load_page(self):
        """Load (once) and return the reference page as a PIL RGB image."""
        if self._page_img is None:
            with Image.open(self.page_path) as img:
                self._page_img = img.convert("RGB")
        return self._page_img

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path, label = self.data[idx]
        with Image.open(path) as img:
            crop_img = img.convert("RGB")
        page_img = self._load_page()

        if self.transform is not None:
            crop_img = self.transform(crop_img)
            page_img = self.transform(page_img)

        return crop_img, page_img, torch.tensor(label, dtype=torch.float)


In [None]:
# --- Training configuration ---------------------------------------------------
SEED = 0  # fixes weight initialisation and DataLoader shuffling order
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Resize to the 224x224 input size PatchEmbedding expects, then convert to tensors.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = ChartMatchDataset("train/cropped", "train/nonchart", transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

model = SiameseViT().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCELoss()

# --- Training loop ------------------------------------------------------------
epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for img1, img2, labels in dataloader:
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

        optimizer.zero_grad()
        out1, out2 = model(img1, img2)
        sim = torch.cosine_similarity(out1, out2)  # values in [-1, 1], shape [B]
        # BCELoss only accepts inputs in [0, 1]; raw cosine similarity is often
        # negative and would make it raise. Rescale to [0, 1] and clamp away
        # from the endpoints so the log terms stay finite.
        prob = ((sim + 1) / 2).clamp(1e-6, 1 - 1e-6)
        loss = criterion(prob, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1} Loss: {total_loss / len(dataloader):.4f}")


In [None]:
def _to_gray(img):
    """Convert an (H, W), (H, W, 1), (H, W, 3) RGB or (H, W, 4) RGBA array to one gray channel."""
    if img.ndim == 2:
        return img
    channels = img.shape[2]
    if channels == 1:
        return img[:, :, 0]
    if channels == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)


def extract_crop(input_img, template_img):
    """Locate ``template_img`` inside ``input_img`` and return the best-matching region.

    Both images are converted to grayscale and compared with OpenCV's
    normalized correlation coefficient (``TM_CCOEFF_NORMED``). The region of
    ``input_img`` at the highest-scoring location is returned.

    Args:
        input_img: Image to search, a NumPy array (H, W), (H, W, 3) in RGB order,
            or (H, W, 4) RGBA.
        template_img: Image to look for, same layout conventions.

    Returns:
        ``(cropped, max_val)``. ``cropped`` is the slice of ``input_img`` with the
        template's height and width at the best match. ``max_val`` is the match
        score. Returns ``(None, -1)`` when the template is larger than the input
        in either dimension.

    Raises:
        ValueError: if either image is ``None`` (e.g. ``cv2.imread`` could not read the file).
    """
    if input_img is None or template_img is None:
        raise ValueError(
            "extract_crop received None instead of an image "
            "(did cv2.imread fail to read the file?)"
        )

    ih, iw = input_img.shape[:2]
    th, tw = template_img.shape[:2]

    # matchTemplate requires the template to fit inside the search image.
    if th > ih or tw > iw:
        print(f"⚠️ Skipping template (too large): {template_img.shape} > {input_img.shape}")
        return None, -1

    input_gray = _to_gray(input_img)
    template_gray = _to_gray(template_img)

    result = cv2.matchTemplate(input_gray, template_gray, cv2.TM_CCOEFF_NORMED)
    # minMaxLoc reports locations as (x, y) of the template's top-left corner.
    _, max_val, _, (x, y) = cv2.minMaxLoc(result)

    cropped = input_img[y:y + th, x:x + tw]
    return cropped, max_val


In [None]:
def match_and_predict(input_img, template_img, model, transform, device, template_tensor=None):
    """Find the template in the input image, then score the match with the Siamese model.

    Args:
        input_img: Page image to search, a NumPy array in RGB channel order
            (the model is trained on PIL RGB images).
        template_img: Chart image to look for, a NumPy array in RGB order.
        model: Siamese model called as ``model(a, b)`` returning two embedding batches.
        transform: Preprocessing applied to PIL images, as used during training.
        device: Device the model lives on.
        template_tensor: Optional pre-processed template batch ``[1, C, H, W]``.
            When omitted it is computed from ``template_img`` with ``transform``.

    Returns:
        ``(sim, max_val)``. ``sim`` is the cosine similarity between the crop and
        template embeddings, or ``None`` if no crop could be extracted. For a model
        returning ``[1, D]`` embeddings, ``sim`` has shape ``[1]``. ``max_val`` is
        the template-matching score from ``extract_crop``.
    """
    cropped, max_val = extract_crop(input_img, template_img)

    if cropped is None:
        return None, max_val

    # Convert to PIL first so preprocessing matches training, where the dataset
    # passes PIL images to ``transform``.
    if template_tensor is None:
        template_tensor = transform(Image.fromarray(template_img)).unsqueeze(0)
    template_tensor = template_tensor.to(device)
    cropped_tensor = transform(Image.fromarray(cropped)).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        out1, out2 = model(cropped_tensor, template_tensor)
        sim = torch.cosine_similarity(out1, out2)

    return sim, max_val

# Example usage: locate a known chart crop on the page and score it with the model.
# Paths are relative to the notebook's working directory, like the training data.
PAGE_PATH = "page.png"
TEMPLATE_PATH = os.path.join("train", "cropped", "cropped_2.png")


def load_rgb(path):
    """Read an image with OpenCV and return it as an RGB NumPy array.

    cv2.imread returns BGR (or None on failure); the matching and model code
    treat arrays as RGB, so convert here.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


input_img = load_rgb(PAGE_PATH)
template_img = load_rgb(TEMPLATE_PATH)
# Preprocess the template like the training pairs: PIL image -> transform -> batch of one.
template_tensor = transform(Image.fromarray(template_img)).unsqueeze(0).to(device)

sim, max_val = match_and_predict(input_img, template_img, model, transform, device)

if sim is not None:
    print(f"Cosine Similarity: {sim.item()}, Match Quality: {max_val}")
else:
    print("No match found.")
