In [1]:
import torch
from modelscope import AutoModel

MODEL_NAME = "/home/public/dkx/model/BAAI/BGE-VL-v1.5-zs"

model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
model.eval()
model.cuda()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LLaVANextForEmbedding(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPSdpaAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELU

In [2]:
from pymilvus import MilvusClient

client = MilvusClient(
    uri="http://localhost:19530"
)

client.load_collection("test")

In [3]:
client.list_collections()

['test', 'bar']

In [4]:
with torch.no_grad():
    model.set_processor(MODEL_NAME)

    query_inputs = model.data_process(
        text="""
Agriculture and Food Production, Production of Different Crops In An Area, metric tons, leptokurtic, horizontal
Crop category,metric tons
Rapeseed,20000
Wheat,2694
Apple,2488
Sugarcane,7573
""",
        q_or_c="q",
        task_instruction="Recommend the most suitable chart with corresponding description for visualizing the information given by the provided text: "
    )

    query_embs = model(**query_inputs, output_hidden_states=True)[:, -1, :]

    query_embs = torch.nn.functional.normalize(query_embs, dim=-1)

    print(len(query_embs.cpu().detach().tolist()[0]))

4096


In [63]:
text_search_results = client.search(
    collection_name="test",
    anns_field="text_dense",
    data=query_embs.cpu().detach().tolist(),
    limit=10,
    search_params={"metric_type": "IP"},
    output_fields=["id", "image_url", "data", "text_dense", "img_dense"],  # specifies fields to be returned
)

for text_search_result in text_search_results[0]:
    print(text_search_result["id"])
    print(text_search_result["distance"])


21
267.5112609863281
273
242.0521240234375
395
231.0052490234375
684
224.5597686767578
84
220.9500732421875
454
219.68634033203125
911
217.21353149414062
157
214.20343017578125
742
211.1539764404297
540
210.75335693359375


In [64]:
img_search_results = client.search(
    collection_name="test",
    anns_field="img_dense",
    data=query_embs.cpu().detach().tolist(),
    limit=10,
    search_params={"metric_type": "IP"},
    output_fields=["id", "image_url", "data", "text_dense", "img_dense"],  # specifies fields to be returned
)

for img_search_result in img_search_results[0]:
    print(img_search_result["id"])
    print(img_search_result["distance"])

911
111.25525665283203
273
111.1042251586914
742
108.90589904785156
141
108.09735107421875
21
106.78152465820312
84
104.885986328125
454
103.28755950927734
540
101.20672607421875
395
98.39408874511719
157
96.54594421386719


In [57]:
from pymilvus import AnnSearchRequest

# text semantic search (dense)
request_1 = AnnSearchRequest(
    data=query_embs.cpu().detach().tolist(),
    anns_field="text_dense",
    param={
        "metric_type": "IP"
    },
    limit=10
)

# text-to-image search (multimodal)
request_2 = AnnSearchRequest(
    data=query_embs.cpu().detach().tolist(),
    anns_field="img_dense",
    param={
        "metric_type": "IP"
    },
    limit=10
)

reqs = [request_1, request_2]

In [58]:
from pymilvus import RRFRanker, WeightedRanker

rrf_ranker = RRFRanker(100)
weighed_ranker = WeightedRanker(0, 1)

In [67]:
hybrid_search_results = client.hybrid_search(
    collection_name="test",  # target collection
    reqs=reqs,
    ranker=weighed_ranker,
    limit=10,  # number of returned entities
    output_fields=["id", "image_url", "metadata"],  # specifies fields to be returned
)

In [68]:
for hybrid_search_result in hybrid_search_results[0]:
    print(hybrid_search_result["id"])
    print(hybrid_search_result["metadata"])

911
{'distribution': ' step', 'display': ' vertical', 'header': ['Crop category', 'metric tons'], 'unit': ' metric tons'}
273
{'distribution': ' leptokurtic', 'display': ' vertical', 'header': ['Crop category', 'metric tons'], 'unit': ' metric tons'}
742
{'distribution': ' bimodal', 'display': ' vertical', 'header': ['Crop category', 'metric tons'], 'unit': ' metric tons'}
141
{'distribution': ' step', 'display': ' vertical', 'header': ['Crop category', 'metric tons'], 'unit': ' metric tons'}
21
{'distribution': ' leptokurtic', 'display': ' horizontal', 'header': ['Crop category', 'metric tons'], 'unit': ' metric tons'}
84
{'distribution': ' bimodal', 'display': ' vertical', 'header': ['Crop category', 'metric tons'], 'unit': ' metric tons'}
454
{'distribution': ' long_tail', 'display': ' vertical', 'header': ['Crop category', 'metric tons'], 'unit': ' metric tons'}
540
{'distribution': ' long_tail', 'display': ' horizontal', 'header': ['Crop category', 'metric tons'], 'unit': ' metric

In [14]:
with torch.no_grad():
    model.set_processor(MODEL_NAME)

    query_inputs1 = model.data_process(
        text="""
Agriculture and Food Production, Production of Different Crops In An Area, metric tons, leptokurtic, horizontal
Crop category,metric tons
Rapeseed,20000
Wheat,2694
Apple,2488
Sugarcane,7573
""",
        q_or_c="q",
        task_instruction="Recommend the most suitable chart with corresponding description for visualizing the information given by the provided text: "
    )

    query_embs1 = model(**query_inputs1, output_hidden_states=True)[:, -1, :]

    query_embs1 = torch.nn.functional.normalize(query_embs1, dim=-1)

    print(len(query_embs1.cpu().detach().tolist()[0]))

4096


In [15]:
res1 = client.search(
    collection_name="bar",  # target collection
    data=query_embs1.cpu().detach().tolist(),  # query vectors
    limit=10,  # number of returned entities
    output_fields=["id", "image_url", "data"],  # specifies fields to be returned
)

for x in res1[0]:
    print(x["id"])
    print(x["distance"])

911
0.3492613136768341
273
0.34827539324760437
141
0.3397336006164551
742
0.3390982151031494
84
0.3294185698032379
21
0.327197790145874
454
0.3228279948234558
540
0.31269580125808716
395
0.3067818284034729
157
0.302957147359848
