In [None]:
%%writefile requirements.txt
#@title requirements.txt
# Runtime dependencies of app.py. The %%writefile cell magic has to be the
# first line of the cell, so the Colab title comment goes below it.
streamlit
transformers
torch
Pillow
numpy

In [None]:
#@title install
# %pip (instead of !pip) installs into the environment of the running kernel.
%pip install -q -r requirements.txt

In [None]:
%%writefile app.py
#@title app.py version 0.1



# imports

from transformers import FlavaModel, BertTokenizer, FlavaFeatureExtractor
import numpy as np
from PIL import Image
import torch
import streamlit as st

st.header('Image Classification')


@st.cache_resource
def _load_flava(checkpoint="facebook/flava-full"):
    """Load the FLAVA model, image feature extractor and tokenizer.

    Streamlit re-executes the whole script on every user interaction;
    ``st.cache_resource`` keeps the loaded objects alive across those reruns
    so the checkpoint is not loaded again each time.

    Args:
        checkpoint: Name of the pretrained checkpoint to load.

    Returns:
        tuple: ``(model, feature_extractor, tokenizer)``; ``eval()`` has been
        called on the model.
    """
    flava = FlavaModel.from_pretrained(checkpoint)
    flava.eval()
    extractor = FlavaFeatureExtractor.from_pretrained(checkpoint)
    bert_tokenizer = BertTokenizer.from_pretrained(checkpoint)
    return flava, extractor, bert_tokenizer


# Module-level names used by the prediction function below.
model, fe, tokenizer = _load_flava()



def load_image(image_file):
    """Open ``image_file`` with Pillow and return the resulting image object.

    Args:
        image_file: Whatever ``PIL.Image.open`` accepts; the app passes the
            object returned by the Streamlit file uploader.
    """
    return Image.open(image_file)


def shot2(img, labels_text):
    """Pick the candidate label that best matches an image (zero-shot).

    Each label is inserted into the prompt ``"This is a photo of a <label>"``,
    the prompts and the image are embedded with the module-level FLAVA
    ``model``, and the label with the highest softmax score is returned.

    Args:
        img: A PIL image (converted to RGB here) or array-like image data
            that is handed to the feature extractor unchanged.
        labels_text: Comma-separated candidate labels, e.g. ``"car, not car"``.

    Returns:
        str: The winning label with surrounding whitespace removed.

    Raises:
        ValueError: If ``labels_text`` contains no non-empty label.
    """
    # Trim every label and drop blanks, so "a, ,b" or a trailing comma does
    # not create an empty candidate and results carry no stray whitespace.
    labels = [part.strip() for part in labels_text.split(",") if part.strip()]
    if not labels:
        raise ValueError("labels_text must contain at least one label")

    # The uploader accepts PNG, which may be RGBA, palette or grayscale.
    # NOTE(review): the feature extractor presumably expects 3-channel input;
    # confirm against the FLAVA image-processing docs.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    pixels = np.asarray(img)

    prompts = [f"This is a photo of a {label}" for label in labels]
    image_input = fe([pixels], return_tensors="pt")
    text_inputs = tokenizer(prompts, padding="max_length", return_tensors="pt")

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # Index 0 along the second axis is taken from both feature tensors.
        image_embeddings = model.get_image_features(**image_input)[:, 0, :]
        text_embeddings = model.get_text_features(**text_inputs)[:, 0, :]
        logits = (text_embeddings @ image_embeddings.T).squeeze(0)
        similarities = torch.nn.functional.softmax(logits, dim=0)

    res = {label: similarities[idx].item() for idx, label in enumerate(labels)}
    return max(res, key=res.get)


def _is_vehicle_label(name):
    """Return True when ``name`` names a car, e.g. ``car`` or ``sports car``.

    Matching is done on whole words, and any label containing the word
    ``not`` (such as ``not car``) is rejected. A plain substring test would
    also accept ``not car`` or ``carrot``.
    """
    words = str(name).lower().split()
    return 'car' in words and 'not' not in words


image_file = st.sidebar.file_uploader("Upload Images", type=["png","jpg","jpeg"])

with st.sidebar.expander(''):
    possible_labels = st.text_area('Possible labels:', 'car, map, not car, baby, fruit, vegetable, building')

if image_file is not None:

    # Show the metadata of the uploaded file.
    file_details = {"filename":image_file.name, "filetype":image_file.type,
                    "filesize":image_file.size}
    with st.expander('Show picture details'):
        st.write(file_details)

    # Show the uploaded image and the predicted class.
    st.header('Show picture and predict label')
    with st.spinner():
        # Open the upload once and reuse it for display and prediction.
        picture = load_image(image_file)
        st.image(picture, width=250)
        detected_name = shot2(picture, possible_labels)
        if _is_vehicle_label(detected_name):
            label = 'vehicle'
            st.success(label)
        else:
            label = 'non-vehicle'
            st.error(label)

            with st.expander('Show what can be on the pic'):
                st.info(detected_name)

In [None]:
%%writefile app.py
#@title app.py version 0.2


# imports

from transformers import FlavaModel, BertTokenizer, FlavaFeatureExtractor
import numpy as np
from PIL import Image
import torch
import streamlit as st

st.header('Image Classification')


@st.cache_resource
def _load_flava(checkpoint="facebook/flava-full"):
    """Load the FLAVA model, image feature extractor and tokenizer.

    Streamlit re-executes the whole script on every user interaction;
    ``st.cache_resource`` keeps the loaded objects alive across those reruns
    so the checkpoint is not loaded again each time.

    Args:
        checkpoint: Name of the pretrained checkpoint to load.

    Returns:
        tuple: ``(model, feature_extractor, tokenizer)``; ``eval()`` has been
        called on the model.
    """
    flava = FlavaModel.from_pretrained(checkpoint)
    flava.eval()
    extractor = FlavaFeatureExtractor.from_pretrained(checkpoint)
    bert_tokenizer = BertTokenizer.from_pretrained(checkpoint)
    return flava, extractor, bert_tokenizer


# Module-level names used by the prediction function below.
model, fe, tokenizer = _load_flava()



def load_image(image_file):
    """Return the Pillow image obtained by opening ``image_file``.

    Args:
        image_file: A path or file-like object understood by
            ``PIL.Image.open``; the app passes the uploaded file here.
    """
    return Image.open(image_file)

# NOTE(review): st.cache is deprecated in recent Streamlit releases in favour
# of st.cache_data; migrate once hashing of the PIL image argument has been
# confirmed to work with the new API.
@st.cache()
def predict(img, labels_text):
    """Pick the candidate label that best matches an image (zero-shot).

    Each label is inserted into the prompt ``"This is a photo of a <label>"``,
    the prompts and the image are embedded with the module-level FLAVA
    ``model``, and the label with the highest softmax score is returned.

    Args:
        img: A PIL image (converted to RGB here) or array-like image data
            that is handed to the feature extractor unchanged.
        labels_text: Comma-separated candidate labels, e.g. ``"bus, tree"``.

    Returns:
        str: The winning label with surrounding whitespace removed. Spaces
        inside a multi-word label such as ``"sports car"`` are kept.

    Raises:
        ValueError: If ``labels_text`` contains no non-empty label.
    """
    # Trim at parse time rather than deleting every space from the result,
    # which would turn "sports car" into "sportscar". Blank entries are dropped.
    labels = [part.strip() for part in labels_text.split(",") if part.strip()]
    if not labels:
        raise ValueError("labels_text must contain at least one label")

    # The uploader accepts PNG, which may be RGBA, palette or grayscale.
    # NOTE(review): the feature extractor presumably expects 3-channel input;
    # confirm against the FLAVA image-processing docs.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    pixels = np.asarray(img)

    prompts = [f"This is a photo of a {label}" for label in labels]
    image_input = fe([pixels], return_tensors="pt")
    text_inputs = tokenizer(prompts, padding="max_length", return_tensors="pt")

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # Index 0 along the second axis is taken from both feature tensors.
        image_embeddings = model.get_image_features(**image_input)[:, 0, :]
        text_embeddings = model.get_text_features(**text_inputs)[:, 0, :]
        logits = (text_embeddings @ image_embeddings.T).squeeze(0)
        similarities = torch.nn.functional.softmax(logits, dim=0)

    res = {label: similarities[idx].item() for idx, label in enumerate(labels)}
    return max(res, key=res.get)


def _split_labels(text):
    """Split comma-separated ``text`` into trimmed, non-empty labels."""
    return [part.strip() for part in text.split(',') if part.strip()]


def _label_key(label):
    """Normalise a label for comparison: lower-case with all spaces removed.

    Removing spaces keeps the comparison valid whether or not the predicted
    label still contains its inner spaces.
    """
    return str(label).replace(' ', '').lower()


image_file = st.sidebar.file_uploader("Upload Images", type=["png","jpg","jpeg"])

with st.sidebar.expander(''):
    possible_labels = st.text_area('Possible labels:', 'map, baby, fruit, vegetable, building, person, furniture, tree, sea')
    target_labels = st.text_area('Vehicle-related words:', 'bus, limousine, pickup, SUV, sedan')
    target_labels_list = _split_labels(target_labels)

    # Merge both lists into one comma-separated string. dict.fromkeys drops
    # duplicates while keeping insertion order, so the string is stable
    # between runs and labels containing quotes or braces stay intact.
    possible_labels = ', '.join(
        dict.fromkeys(_split_labels(possible_labels) + target_labels_list)
    )


if image_file is not None:

    # Show the metadata of the uploaded file.
    file_details = {"filename":image_file.name, "filetype":image_file.type,
                    "filesize":image_file.size}
    with st.expander('Show picture details'):
        st.write(file_details)

    # Show the uploaded image and the predicted class.
    st.header('Show picture and predict label')

    # Open the upload once and reuse it for display and prediction.
    picture = load_image(image_file)
    st.image(picture, width=250)
    detected_name = predict(picture, possible_labels)

    # Exact label comparison; a substring test against the list's repr would
    # let a label like "up" match "pickup".
    vehicle_keys = {_label_key(target) for target in target_labels_list}
    if _label_key(detected_name) in vehicle_keys:
        label = 'vehicle'
        st.success(label)
        with st.expander('Show what kind of vehicle it can be'):
            st.info(detected_name)
    else:
        label = 'non-vehicle'
        st.error(label)
        with st.expander('Show what can be on the pic'):
            st.info(detected_name)


In [None]:
#@title Generate link to web application
# Start Streamlit in the background on an explicit port and expose that same
# port through localtunnel, so the tunnel and the server cannot disagree.
! streamlit run app.py --server.port 8501 & npx localtunnel --port 8501