In [None]:
# Dependencies for the Streamlit app written below. %pip (rather than !pip)
# installs into the same Python environment as the running kernel.
%pip install streamlit
# localtunnel (Node.js) is used in a later cell to expose the Streamlit port.
!npm install localtunnel
%pip install --upgrade --no-cache-dir gdown -q

In [None]:
%%writefile app.py
import os

import streamlit as st
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torchvision import transforms
from torch import nn
import gdown

# Streamlit re-executes this whole script on every widget interaction, so the
# trained checkpoint is fetched from Google Drive only when it is not already
# on disk. The output name is pinned explicitly because load_model() reads
# 'classify_food.pt' rather than whatever name the Drive file happens to have.
CHECKPOINT_PATH = 'classify_food.pt'
if not os.path.exists(CHECKPOINT_PATH):
  gdown.download(id='1w3_GY-3-MhJ4M8wis-UZtZJIrLOz5d1F', output=CHECKPOINT_PATH, quiet=True)

class TinyVGG(nn.Module):
  """Small VGG-style CNN: two convolutional blocks followed by a linear classifier.

  Each conv block is ``Conv3x3 -> ReLU -> Conv3x3 -> ReLU -> MaxPool2x2`` with
  no padding, so a block maps a spatial size ``s`` to ``(s - 4) // 2``. The
  classifier's input width is derived from that instead of being hard-coded.

  Args:
    input_shape: Number of input image channels (e.g. 3 for RGB).
    hidden_units: Number of channels produced by every conv layer.
    output_shape: Number of output classes (length of the logit vector).
    image_size: Height/width of the square input images. Defaults to 64, which
      yields the 13x13 feature map the original hard-coded classifier assumed.

  Raises:
    ValueError: If ``image_size`` is too small to survive both conv blocks.
  """

  def __init__(self, input_shape: int,
               hidden_units: int,
               output_shape: int,
               image_size: int = 64) -> None:
    super().__init__()
    # The layer order inside each nn.Sequential matches the original layout, so
    # state_dict keys such as "conv_block_1.0.weight" and "classifier.1.weight"
    # are unchanged and existing checkpoints still load.
    self.conv_block_1 = self._make_conv_block(input_shape, hidden_units)
    self.conv_block_2 = self._make_conv_block(hidden_units, hidden_units)
    spatial = self._output_spatial_size(image_size)
    self.classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=hidden_units * spatial * spatial,
                  out_features=output_shape)
    )

  @staticmethod
  def _make_conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
    """Return one Conv-ReLU-Conv-ReLU-MaxPool block (3x3 convs, stride 1, no padding)."""
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels,
                  out_channels=out_channels,
                  kernel_size=3,
                  stride=1,
                  padding=0),
        nn.ReLU(),
        nn.Conv2d(in_channels=out_channels,
                  out_channels=out_channels,
                  kernel_size=3,
                  stride=1,
                  padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,
                     stride=2)
    )

  @staticmethod
  def _output_spatial_size(image_size: int) -> int:
    """Return the feature-map height/width left after both conv blocks."""
    size = image_size
    for _ in range(2):
      # Two unpadded 3x3 convs remove 4 pixels; the 2x2 stride-2 pool halves (floor).
      size = (size - 4) // 2
    if size < 1:
      raise ValueError(
          f'image_size={image_size} is too small for TinyVGG: '
          'the feature map vanishes before the classifier')
    return size

  def forward(self, x):
    """Return class logits for a batch ``x`` of images.

    ``x`` is expected to be ``(N, input_shape, image_size, image_size)``; the
    result has shape ``(N, output_shape)``.
    """
    x = self.conv_block_1(x)
    x = self.conv_block_2(x)
    return self.classifier(x)

# Class labels, listed in the index order of the classifier's output logits.
categories = ['lasagna', 'pad_thai', 'pho', 'ramen']

@st.cache_resource
def load_model():
  """Build the TinyVGG classifier, load its trained weights, and return its preprocessing.

  Decorated with ``st.cache_resource`` so the checkpoint is read once and the
  same objects are reused across Streamlit reruns.

  Returns:
    tuple: ``(model, preprocess)`` where ``model`` is a ``TinyVGG`` on the CPU in
    eval mode, and ``preprocess`` resizes a PIL image to 64x64 and converts it
    to a tensor via ``transforms.ToTensor``.
  """
  model = TinyVGG(input_shape=3, hidden_units=15, output_shape=len(categories))
  # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
  # NOTE(review): this file is downloaded from the network; on PyTorch < 2.6
  # torch.load unpickles arbitrary objects by default. Consider passing
  # weights_only=True once the checkpoint contents are confirmed compatible.
  checkpoint = torch.load('classify_food.pt', map_location=torch.device('cpu'))
  model.load_state_dict(checkpoint['model_state_dict'])
  # The model is only used for inference; eval() makes that explicit so any
  # future dropout/batch-norm layers behave correctly.
  model.eval()
  preprocess = transforms.Compose([
      # Must match the 64x64 input size TinyVGG's classifier width was built for.
      transforms.Resize(size=(64, 64)),
      transforms.ToTensor(),
  ])
  return model, preprocess

def predict(preprocess, model, img_path):
  """Classify a single image and return its per-class probabilities.

  Args:
    preprocess: Callable mapping a PIL image to a ``(C, H, W)`` tensor, such as
      the transform returned by ``load_model``.
    model: Classifier expected to return logits of shape ``(batch, num_classes)``.
    img_path: Filename or binary file-like object accepted by ``PIL.Image.open``
      (for example a Streamlit ``UploadedFile``).

  Returns:
    torch.Tensor: 1-D tensor of class probabilities (softmax over the logits).
  """
  with Image.open(img_path) as raw_image:
    # PNG uploads can be RGBA or palette mode and JPEGs can be grayscale or CMYK;
    # the model built in load_model expects exactly 3 input channels.
    rgb_image = raw_image.convert('RGB')

  with torch.inference_mode():
    batch = preprocess(rgb_image).unsqueeze(0)  # add a batch dimension of 1
    prediction = model(batch).squeeze(0).softmax(0)

  return prediction

st.title('Classify Food')
file_uploaded = st.file_uploader('Choose a file', type=['jpg', 'png', 'jpeg'])

if file_uploaded is not None:
  image = Image.open(file_uploaded)
  # Use an explicit Figure/Axes pair so the figure can be closed once rendered.
  # pyplot keeps every figure alive until it is closed, and Streamlit reruns
  # this script in the same process on every interaction.
  figure, ax = plt.subplots()
  ax.imshow(image)
  ax.axis('off')

  model, preprocess = load_model()

  prediction = predict(preprocess, model, file_uploaded)

  # prediction holds one probability per entry of `categories`.
  result_index = prediction.argmax().item()
  st.write('Predict:', categories[result_index])
  st.pyplot(figure)
  plt.close(figure)

Overwriting app.py


In [None]:
# Launch the Streamlit app in the background (trailing `&`) so this cell returns
# immediately; `&>` sends both stdout and stderr to /content/logs.txt.
# NOTE(review): assumes %%writefile above wrote app.py into /content (Colab's
# default working directory) — confirm if running outside Colab.
!streamlit run /content/app.py &>/content/logs.txt &

In [None]:
# Expose the Streamlit server on a public URL via localtunnel. 8501 is the port
# Streamlit listens on by default; the launch cell does not override it.
# The command keeps running until the cell is interrupted, which closes the tunnel.
# NOTE(review): localtunnel may show a reminder page asking for a "tunnel
# password" (the host's public IP) before serving the app — confirm for the
# installed localtunnel version.
!npx localtunnel --port 8501

[K[?25hnpx: installed 22 in 3.488s
your url is: https://nine-loops-flash-35-231-174-156.loca.lt
^C
