Skip to content

Commit

Permalink
Cleaning up code
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyashankar committed Mar 7, 2023
1 parent 607f90f commit cf349d5
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 216 deletions.
25 changes: 1 addition & 24 deletions clipft.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import asyncio
import aiohttp
import clip
import logging
import numpy as np
Expand All @@ -9,27 +7,6 @@
from PIL import Image


async def get(url, session):
    """Fetch one URL with the shared aiohttp session.

    Returns a ``(url, body_bytes)`` pair on success, or ``None`` (implicitly,
    via the logged-and-swallowed exception path) when the request fails for
    any reason — callers filter out the ``None`` entries.
    """
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.135 Safari/537.36 Edge/12.246"
    }
    try:
        async with session.get(url=url, headers=headers) as response:
            body = await response.content.read()
    except Exception as err:
        # Best-effort download: record the failure class and drop the URL.
        logging.error(
            "Unable to get url {} due to {}.".format(url, err.__class__)
        )
        return None
    return url, body


async def async_download_image(img_urls):
    """Download all image URLs concurrently.

    Returns two parallel lists ``(urls, blobs)`` covering only the downloads
    that succeeded (failed fetches yield ``None`` from ``get`` and are
    dropped here).
    """
    async with aiohttp.ClientSession() as session:
        tasks = [get(u, session) for u in img_urls]
        results = await asyncio.gather(*tasks)
        successes = [pair for pair in results if pair is not None]
        urls = [pair[0] for pair in successes]
        blobs = [pair[1] for pair in successes]
        return urls, blobs


class OutfitDataset(torch.utils.data.Dataset):
def __init__(self, img_blobs, captions, device):
_, preprocess_fn = clip.load("ViT-B/32", device=device)
Expand Down Expand Up @@ -101,7 +78,7 @@ def fine_tune_model(model, img_blobs, captions, batch_size=16, epochs=5):
)
best_loss = float("inf")

# Train model
# Train and validate model
for epoch in range(epochs):
# Train mode
model.train()
Expand Down
20 changes: 11 additions & 9 deletions dashboard.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
import motion
import streamlit as st

import components
from motion import get_store
from transforms import SuggestIdea, Retrieval
from schemas import QuerySchema, CatalogSchema, QuerySource
from scrapers import scrape_everlane_sale


@st.cache_resource
def setup_database():
# Create store and add triggers
store = motion.get_store("fashion", create=True, memory=True)
store.addNamespace("query", components.QuerySchema)
store.addNamespace("catalog", components.CatalogSchema)
store = get_store("fashion", create=True, memory=True)
store.addNamespace("query", QuerySchema)
store.addNamespace("catalog", CatalogSchema)

store.addTrigger(
name="suggest_idea",
keys=["query.query"],
trigger=components.SuggestIdea,
trigger=SuggestIdea,
)
store.addTrigger(
name="retrieval",
keys=["catalog.img_blob", "query.text_suggestion", "query.feedback"],
trigger=components.Retrieval,
trigger=Retrieval,
)

# Add the catalog
components.scrape_everlane_sale(store, k=20)
scrape_everlane_sale(store, k=20)

return store

Expand All @@ -37,7 +39,7 @@ def run_query(query):
id=None,
key_values={
"query": query,
"src": components.QuerySource.ONLINE,
"src": QuerySource.ONLINE,
"query_id": query_id,
},
)
Expand Down
4 changes: 2 additions & 2 deletions motion/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ def executeTrigger(
# Execute the transform lifecycle
if trigger_fn.shouldFit(id, trigger_elem):
trigger_fn.fit(id, trigger_elem)
if trigger_fn.shouldTransform(id, trigger_elem):
trigger_fn.transform(id, trigger_elem)
if trigger_fn.shouldInfer(id, trigger_elem):
trigger_fn.infer(id, trigger_elem)
logging.info(f"Finished running trigger {trigger_name}.")

def set(
Expand Down
4 changes: 2 additions & 2 deletions motion/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def fit(self, id, triggered_by: TriggerElement):
pass

@abstractmethod
def shouldTransform(self, id, triggered_by: TriggerElement):
def shouldInfer(self, id, triggered_by: TriggerElement):
pass

@abstractmethod
def transform(self, id, triggered_by: TriggerElement):
def infer(self, id, triggered_by: TriggerElement):
pass
33 changes: 33 additions & 0 deletions schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import motion
from typing import TypeVar


class Retailer(motion.MEnum):
    """Closed set of retailers whose catalogs can be ingested.

    Stored in ``CatalogSchema.retailer``; the string values are the
    display names used when rendering results.
    """

    NORDSTROM = "Nordstrom"
    REVOLVE = "Revolve"
    BLOOMINGDALES = "Bloomingdales"
    EVERLANE = "Everlane"


class QuerySource(motion.MEnum):
    """Where a query originated: a batch/offline job or a live user session."""

    OFFLINE = "Offline"
    ONLINE = "Online"


class QuerySchema(motion.Schema):
    """Schema for the ``query`` namespace: one row per query/suggestion step.

    Fields are populated incrementally by triggers (e.g. SuggestIdea fills
    ``text_suggestion``; retrieval fills ``img_id``/``img_score``) —
    presumably not all fields are set on every row; TODO confirm.
    """

    src: QuerySource  # online vs offline origin of the query
    query_id: int  # groups the rows belonging to one user query
    query: str  # raw query text entered by the user
    text_suggestion: str  # generated outfit-idea text for this query
    img_id: int  # catalog image matched to the suggestion
    img_score: float  # relevance score of the matched image
    feedback: bool  # user feedback signal on the result


class CatalogSchema(motion.Schema):
    """Schema for the ``catalog`` namespace: one row per scraped product."""

    retailer: Retailer  # which retailer the product was scraped from
    img_url: str  # source URL of the product image
    img_blob: TypeVar("BLOB")  # raw downloaded image bytes
    img_name: str  # product display name
    permalink: str  # full product-page URL
    img_embedding: TypeVar("FLOAT[]")  # CLIP-style embedding, filled by a trigger — TODO confirm
87 changes: 87 additions & 0 deletions scrapers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import asyncio
import logging
import json
import pandas as pd
import requests

from utils import async_download_image

from bs4 import BeautifulSoup
from schemas import Retailer


def scrape_everlane_sale(store, k=20):
    """Scrape Everlane category pages and ingest up to *k* products.

    For each collection URL, parse the embedded ``__NEXT_DATA__`` JSON
    payload for product metadata, de-duplicate by image URL, download the
    image blobs concurrently, and write each product into the store's
    ``catalog`` namespace.

    Args:
        store: motion store exposing ``getNewId``/``set``; written to as a
            side effect. Nothing is returned.
        k: maximum number of unique products to ingest.
    """
    urls = [
        # "https://www.everlane.com/collections/womens-sale-2",
        "https://www.everlane.com/collections/womens-all-tops",
        "https://www.everlane.com/collections/womens-tees",
        "https://www.everlane.com/collections/womens-sweaters",
        "https://www.everlane.com/collections/womens-sweatshirts",
        "https://www.everlane.com/collections/womens-bodysuits",
        "https://www.everlane.com/collections/womens-jeans",
        "https://www.everlane.com/collections/womens-bottoms",
        "https://www.everlane.com/collections/womens-skirts-shorts",
        "https://www.everlane.com/collections/womens-dresses",
        "https://www.everlane.com/collections/womens-outerwear",
        "https://www.everlane.com/collections/womens-underwear",
        "https://www.everlane.com/collections/womens-perform",
        "https://www.everlane.com/collections/swimwear",
        "https://www.everlane.com/collections/womens-shoes",
    ]
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.135 Safari/537.36 Edge/12.246"
    }
    product_info = []
    for url in urls:
        # Best-effort per page: a timeout keeps one slow endpoint from
        # hanging the whole scrape, and any fetch failure is logged and
        # skipped rather than aborting the remaining collections.
        try:
            r = requests.get(url=url, headers=headers, timeout=30)
            r.raise_for_status()
        except requests.RequestException as e:
            logging.error(f"Unable to fetch {url}: {e}")
            continue

        soup = BeautifulSoup(r.content, "html5lib")

        res = soup.find("script", attrs={"id": "__NEXT_DATA__"})
        # Guard: soup.find returns None when the tag is absent; the old
        # code crashed with AttributeError on such pages.
        if res is None or not res.contents:
            logging.error(f"No __NEXT_DATA__ payload found at {url}")
            continue

        try:
            products = json.loads(res.contents[0])["props"]["pageProps"][
                "fallbackData"
            ]["products"]
        except (KeyError, ValueError) as e:
            logging.error(f"Malformed __NEXT_DATA__ payload at {url}: {e}")
            continue

        for product in products:
            product_info.append(
                {
                    "img_url": product["albums"]["square"][0]["src"],
                    "img_name": product["displayName"],
                    "permalink": "https://www.everlane.com/products/"
                    + product["permalink"],
                }
            )

    # Guard the empty case: DataFrame([]) has no "img_url" column and the
    # drop_duplicates call below would raise.
    if not product_info:
        logging.warning("No products scraped; catalog left unchanged.")
        return

    # Delete duplicates, shuffle, and keep at most k products.
    df = pd.DataFrame(product_info)
    df = (
        df.drop_duplicates(subset=["img_url"])
        .sample(frac=1)
        .reset_index(drop=True)
    )
    logging.info(f"Found {len(df)} unique products.")
    df = df.head(k)

    # Get blobs from the images (concurrent download; failures are dropped).
    img_urls, contents = asyncio.run(
        async_download_image(df["img_url"].values)
    )
    img_url_to_content = dict(zip(img_urls, contents))

    for _, product_row in df.iterrows():
        # Skip products whose image download failed.
        if product_row["img_url"] not in img_url_to_content:
            continue

        new_id = store.getNewId("catalog")
        product = product_row.to_dict()
        product.update(
            {
                "retailer": Retailer.EVERLANE,
                "img_blob": img_url_to_content[product_row["img_url"]],
            }
        )
        store.set("catalog", id=new_id, key_values=product)
47 changes: 0 additions & 47 deletions scratch.py

This file was deleted.

Loading

0 comments on commit cf349d5

Please sign in to comment.