In [166]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
import cv2
from sklearn.metrics.pairwise import cosine_similarity

In [167]:
augmentation_count = 5

In [168]:
def preprocess_image(image_path, size=(224, 224), aug_count=augmentation_count):
    preprocessed = []
    flag = 0
    for _ in range(aug_count):
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5*flag),  # 0.5
            transforms.RandomRotation(degrees=30*flag),  # 30
            transforms.ColorJitter(brightness=0.2*flag, contrast=0.2*flag, saturation=0.2*flag, hue=0.1*flag),
            transforms.Resize(size),
            transforms.CenterCrop(224), 
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        preprocessed.append(transform(image).unsqueeze(0))
        flag = 1
    return preprocessed

In [169]:
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        model = models.vit_b_16(pretrained=True)
        self.feature_extractor = model 
        self.feature_extractor.heads = nn.Identity() 
    
    def forward(self, x):
        x = self.feature_extractor(x) 
        return x

In [170]:
def get_feature_vector(image_path, aug_count=augmentation_count):
    features = []
    image_tensor = preprocess_image(image_path, aug_count=aug_count)
    with torch.no_grad():
        for img in image_tensor:
            feature = extractor(img)
            features.append(feature.numpy())
    return features

In [171]:
def compare_logos(query_image, reference_images, threshold=0.25):
    query_vector = get_feature_vector(query_image, 1)[0]
    similarities = []
    for ref in reference_images:
        ref_vector = get_feature_vector(ref)
        for ref_v in ref_vector:
            sim = cosine_similarity(query_vector, ref_v)[0][0]
            # print(sim)
            similarities.append(sim)
    best_match = max(similarities)
    return best_match, best_match >= threshold

In [172]:
extractor = FeatureExtractor()
extractor.eval()



FeatureExtractor(
  (feature_extractor): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.0, inplace=False)
            (3): Linear(in_features=3072, out_features=768, bias=True)
            (4): Dropout(p=0.0, inplace=False)
          )
        )
        (encoder_layer_1): EncoderBlock(
    

In [173]:
query_logo = "./logos/query_logo_6.png"  # логотип из видео
reference_logos = ["./logos/reference_logo_1.jpg", "./logos/reference_logo_2.png", "./logos/reference_logo_3.jpg", 
                   "./logos/reference_logo_4.jpg", "./logos/reference_logo_5.jpg", "./logos/reference_logo_6.jpg", 
                   "./logos/reference_logo_7.jpg", "./logos/reference_logo_8.jpg", "./logos/reference_logo_9.png", 
                   "./logos/reference_logo_10.png", "./logos/reference_logo_11.jpg", "./logos/reference_logo_12.png"
]  # референсы логотипов


In [176]:
similarity_score, is_match = compare_logos(query_logo, reference_logos)
print(f"Логотип соответствует референсному набору: {is_match}")
print(f"Уровень сходства: {similarity_score:.2f}")

Логотип соответствует референсному набору: True
Уровень сходства: 0.40
