In [None]:
import os

import pandas as pd
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForQuestionAnswering


In [None]:
# Load the pretrained BLIP VQA model together with the processor from the same
# checkpoint; the processor is used below both to encode image+text inputs and
# to decode generated token ids back into text.
BLIP_CHECKPOINT = "Salesforce/blip-vqa-base"
processor = AutoProcessor.from_pretrained(BLIP_CHECKPOINT)
model = BlipForQuestionAnswering.from_pretrained(BLIP_CHECKPOINT)


In [None]:
# Generic, entity-agnostic prompt.
# NOTE(review): the training and inference loops each reassign `text` to an
# entity-specific prompt, so this value is not what they send to the model.
text = (
    "List all the measurements in this image. "
    "For example, 100 grams or 1000 kgs. "
    "List only the measurement, nothing else. "
    "Print <NO> if the measurement cant be found"
)


In [None]:
# Work with small leading slices of each split while experimenting.
N_TRAIN_ROWS = 300  # rows of train.csv used for fine-tuning
N_TEST_ROWS = 50    # rows of test.csv used for inference

df = pd.read_csv("train.csv")
sample = df.head(N_TRAIN_ROWS)

test_df = pd.read_csv("test.csv")
test_sample = test_df.head(N_TEST_ROWS)

In [None]:
# Train: fine-tune BLIP-VQA so that, given a product image and an entity type,
# it generates that row's entity value as text.
# Each step applies the gradients with an optimizer and then clears them.
# A bare `loss.backward()` only accumulates gradients across iterations and
# never updates the weights.
LEARNING_RATE = 5e-5  # NOTE(review): common fine-tuning starting point; tune as needed

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

model.train()  # enable training-time behaviour (e.g. dropout) for the updates
try:
  for idx, (url, _, ent_type, ent_val) in enumerate(sample.values):
    print(f"{idx}: {url}")

    # timeout: a stalled download would otherwise hang the loop forever.
    # raise_for_status: surface HTTP errors instead of an opaque PIL decode error.
    with requests.get(url, stream=True, timeout=30) as response:
      response.raise_for_status()
      # convert() forces PIL to read the whole image before the response is
      # closed, and normalizes grayscale/RGBA/palette images to 3-channel RGB.
      image = Image.open(response.raw).convert("RGB")

    text = f"List the {ent_type} measurement in this image. In the format of 100 grams or 1000 kgs, etc. List only the measurement, nothing else. Print <NO> if the measurement cant be found"

    inputs = processor(images=image, text=text, return_tensors="pt")
    inputs["labels"] = processor(text=ent_val, return_tensors="pt").input_ids

    outputs = model(**inputs)
    loss = outputs.loss

    optimizer.zero_grad()  # drop gradients left over from the previous step
    loss.backward()
    optimizer.step()       # actually update the weights
finally:
  # Leave the model in inference mode even if the loop is interrupted, so any
  # later generate() call is not affected by dropout.
  model.eval()

0: https://m.media-amazon.com/images/I/61I9XdN6OFL.jpg
1: https://m.media-amazon.com/images/I/71gSRbyXmoL.jpg
2: https://m.media-amazon.com/images/I/61BZ4zrjZXL.jpg
3: https://m.media-amazon.com/images/I/612mrlqiI4L.jpg
4: https://m.media-amazon.com/images/I/617Tl40LOXL.jpg
5: https://m.media-amazon.com/images/I/61QsBSE7jgL.jpg
6: https://m.media-amazon.com/images/I/81xsq6vf2qL.jpg
7: https://m.media-amazon.com/images/I/71DiLRHeZdL.jpg
8: https://m.media-amazon.com/images/I/91Cma3RzseL.jpg
9: https://m.media-amazon.com/images/I/71jBLhmTNlL.jpg
10: https://m.media-amazon.com/images/I/81N73b5khVL.jpg
11: https://m.media-amazon.com/images/I/61oMj2iXOuL.jpg
12: https://m.media-amazon.com/images/I/91LPf6OjV9L.jpg
13: https://m.media-amazon.com/images/I/81fOxWWWKYL.jpg
14: https://m.media-amazon.com/images/I/81dzao1Ob4L.jpg
15: https://m.media-amazon.com/images/I/91-iahVGEDL.jpg
16: https://m.media-amazon.com/images/I/81S2+GnYpTL.jpg
17: https://m.media-amazon.com/images/I/81e2YtCOKvL.jpg
18

In [None]:
# Inference: ask the model for the requested entity's measurement on each test
# image and print the decoded answer.
# The prompt is kept word-for-word identical to the training prompt, so the
# fine-tuned model sees the same input format it was trained on.
model.eval()  # disable dropout in case a training cell left the model in train mode

for idx, (_row_index, url, _id, ent_type) in enumerate(test_sample.values):
  print(f"{idx} | {ent_type} : {url}")

  # timeout: a stalled download would otherwise hang the loop forever.
  # raise_for_status: surface HTTP errors instead of an opaque PIL decode error.
  with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # convert() loads the full image before the response closes and
    # normalizes grayscale/RGBA/palette images to 3-channel RGB.
    image = Image.open(response.raw).convert("RGB")

  text = f"List the {ent_type} measurement in this image. In the format of 100 grams or 1000 kgs, etc. List only the measurement, nothing else. Print <NO> if the measurement cant be found"

  inputs = processor(images=image, text=text, return_tensors="pt")
  outputs = model.generate(**inputs)
  print(processor.decode(outputs[0], skip_special_tokens=True))

0 | height : https://m.media-amazon.com/images/I/110EibNyclL.jpg




30000
1 | width : https://m.media-amazon.com/images/I/11TU2clswzL.jpg
no
2 | height : https://m.media-amazon.com/images/I/11TU2clswzL.jpg
no
3 | depth : https://m.media-amazon.com/images/I/11TU2clswzL.jpg
no
4 | depth : https://m.media-amazon.com/images/I/11gHj8dhhrL.jpg
no
5 | height : https://m.media-amazon.com/images/I/11gHj8dhhrL.jpg
no
6 | width : https://m.media-amazon.com/images/I/11gHj8dhhrL.jpg
no
7 | height : https://m.media-amazon.com/images/I/11lshEUmCrL.jpg
30000
8 | width : https://m.media-amazon.com/images/I/21+i52HRW4L.jpg
10
9 | height : https://m.media-amazon.com/images/I/21-LmSmehZL.jpg
100 m
10 | item_weight : https://m.media-amazon.com/images/I/213oP6n7jtL.jpg
1. 0
11 | width : https://m.media-amazon.com/images/I/213wY3gUsmL.jpg
30 cm
12 | depth : https://m.media-amazon.com/images/I/214CLs1oznL.jpg
10000
13 | height : https://m.media-amazon.com/images/I/214CLs1oznL.jpg
10000
14 | width : https://m.media-amazon.com/images/I/214CLs1oznL.jpg
10
15 | item_weight : http

KeyboardInterrupt: 