# Install packages

In [None]:
# Presumably required as the Excel engine for the pd.read_excel call on the
# .xlsx label file further down - confirm before removing.
!pip install openpyxl

# Import packages

In [None]:
import io
import os
import requests

import pandas as pd
from PIL import Image, ImageFile
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Ask Pillow to load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Load the TensorBoard notebook extension so the %tensorboard magic used at
# the end of the notebook is available.
%load_ext tensorboard

# Set paths

In [None]:
# Base directory holding the label spreadsheet, the downloaded images and the
# generated label CSVs.
root = "/dbfs/sketches"
# Sub-directory of `root` where downloaded images are stored.
img_path = "images"
# Excel workbook (relative to `root`) with the image URLs and labels.
label_path = "Training_data_Jan2022.xlsx"
# Base name of the label CSVs; written with "train_" / "val_" prefixes.
label_file = "sketches_mcml.csv"

# Placeholder string used to fill missing label values.
fill_str = "None"

# Read data

In [None]:
# Load the label spreadsheet from the data root and preview the first rows.
label_xlsx = os.path.join(root, label_path)
sketch_df = pd.read_excel(label_xlsx)
sketch_df.head()

In [None]:
# Report how many distinct values each identifying column holds.
for label, column in (
    ("URL's", "Image URL"),
    ("ID's", "Image Id"),
    ("Product number's", "Product Number"),
    ("Name's", "Image Name"),
):
    print(f"Unique {label} {sketch_df[column].nunique()}")

# Download data

In [None]:
def fetch_images(
    data: pd.DataFrame,
    dir_path: str,
    img_path: str,
    img_ext: str = "jpg",
    timeout: float = 30.0,
) -> None:
    """Download the images listed in ``data`` and store them as RGB files.

    Rows whose target file already exists are skipped, so the function can
    be re-run to resume an interrupted download. Responses with a non-200
    status and payloads that cannot be opened as an image are counted and
    skipped; the counts are printed at the end.

    Args:
        data: Input dataframe with an "Image URL" and an "Image Id" column.
        dir_path: Root directory.
        img_path: Directory, relative to ``dir_path``, to save images to.
            The function does not create it.
        img_ext: Image extension to save as.
        timeout: Seconds to wait for the server before a request is
            abandoned. Without a timeout a stalled connection would block
            the whole download indefinitely.

    Raises:
        SystemExit: If a request fails at the transport level (connection
            error, timeout, invalid URL, ...).
        OSError: If an image cannot be converted/written even after one
            retry.
    """
    save_dir = os.path.join(dir_path, img_path)

    cnt_downloaded, cnt_exists, cnt_resp_err, cnt_bad_img = 0, 0, 0, 0
    for _, row in tqdm(data.iterrows(), unit="rows", total=len(data)):
        url = row["Image URL"]
        id_ = row["Image Id"]
        file_name = os.path.join(save_dir, f"{id_}.{img_ext}")

        if os.path.isfile(file_name):
            cnt_exists += 1
            continue

        try:
            response = requests.get(url, timeout=timeout)
        except requests.exceptions.RequestException as e:
            # Same abort behaviour as before, but keep the original
            # exception attached as the cause.
            raise SystemExit(e) from e

        if response.status_code != 200:
            cnt_resp_err += 1
            continue

        try:
            img = Image.open(io.BytesIO(response.content))
        except OSError:
            # The server answered 200 but the body is not a decodable image
            # (e.g. an HTML error page); skip it rather than abort the run.
            cnt_bad_img += 1
            continue

        with img:
            try:
                img.convert("RGB").save(file_name)
            except OSError:
                # Single retry kept from the original implementation;
                # presumably a workaround for decoders that fail on the
                # first load of a damaged file. A second failure propagates.
                img.convert("RGB").save(file_name)
        cnt_downloaded += 1

    print(f"Images downloaded: {cnt_downloaded}!")
    print(f"Non existent urls: {cnt_resp_err}!")
    print(f"Images already exist: {cnt_exists}!")
    print(f"Unreadable images: {cnt_bad_img}!")

## Fix IDs

In [None]:
# Keep one row per distinct image URL. NOTE: groupby(...).first() takes the
# first non-null value of every column within each URL group.
fetch_data = sketch_df.groupby("Image URL").first().reset_index()
print(f"Unique URL's: {fetch_data.shape}")
# Drop the rows whose URL is the "Image not found" marker. The explicit copy
# makes the result an independent frame, so assigning to its columns later
# does not trigger pandas' SettingWithCopyWarning.
fetch_data = fetch_data[fetch_data["Image URL"] != "Image not found"].copy()
print(f"Valid unique URL's: {fetch_data.shape}")

In [None]:
# Compare the number of distinct URLs with the number of distinct image ids.
for label, column in (("URL's", "Image URL"), ("ID's", "Image Id")):
    print(f"Unique {label} {fetch_data[column].nunique()}")

In [None]:
# Build unique image ids: prefix every id with "img_" and give each repeat of
# an id a numbered suffix ("_1" for the first repeat, "_2" for the next, ...).
# Numbering the repeats keeps ids that occur three or more times from
# colliding, which a fixed "_1" suffix would not.
ids = fetch_data["Image Id"].astype("str")
dup_rank = ids.groupby(ids).cumcount()  # 0 for the first occurrence of an id
mask = dup_rank > 0
fetch_data["Image Id"] = "img_" + ids
fetch_data.loc[mask, "Image Id"] += "_" + dup_rank[mask].astype("str")

In [None]:
# Re-check the distinct URL / image id counts after the ids were rewritten.
for label, column in (("URL's", "Image URL"), ("ID's", "Image Id")):
    print(f"Unique {label} {fetch_data[column].nunique()}")

## Download

In [None]:
# Make sure the image directory exists, then download every image as PNG.
img_dir = os.path.join(root, img_path)
os.makedirs(img_dir, exist_ok=True)
fetch_images(data=fetch_data, dir_path=root, img_path=img_path, img_ext="png")

In [None]:
# List what is on disk in the image directory and report the file count.
img_dir = os.path.join(root, img_path)
file_list = os.listdir(img_dir)
print(f"Number of images: {len(file_list)}")

# Create labels

In [None]:
# Columns removed before building the labels. The names are presumably spelled
# exactly as in the spreadsheet header - check before "fixing" any of them.
del_cols = [
    "Image URL",
    "Image Name",
    "Product Number",
    "Garment group",
    "Department Name",
    "Seasonold",
    "UniquieVal",
]
data = fetch_data.drop(columns=del_cols)
print(data.shape)
data.head()

In [None]:
# Print the full per-column summary of `data` to spot columns with missing values.
data.info(verbose=True)

## Fix nulls

In [None]:
# Replace every missing value with the placeholder string, then re-inspect.
data_clean = data.fillna(fill_str)
data_clean.info(verbose=True)

In [None]:
# Report the number of distinct values in each label column.
for label, column in (
    ("types", "Type"),
    ("category", "Category"),
    ("subcategory", "SubCategory"),
    ("customer group", "Customer Group"),
):
    print(f"Numberof unique {label}: {data_clean[column].nunique()}")

## Filter data

In [None]:
# Keep only the label rows whose image is actually present on disk. The
# extension is stripped with splitext so the match does not depend on the
# extension being exactly three characters long.
id_list = [os.path.splitext(f)[0] for f in file_list]
data_filt = data_clean[data_clean["Image Id"].isin(id_list)]
data_filt.shape

## Check labels

In [None]:
# Bar chart of how often each "Type" value occurs.
type_vc = data_filt["Type"].value_counts()
type_vc.plot(kind="bar", figsize=(15, 8))

In [None]:
# Bar chart of how often each "Category" value occurs.
cat_vc = data_filt["Category"].value_counts()
cat_vc.plot(kind="bar", figsize=(15, 8))

In [None]:
# Bar chart of how often each "SubCategory" value occurs.
scat_vc = data_filt["SubCategory"].value_counts()
scat_vc.plot(kind="bar", figsize=(15, 8))

In [None]:
# Bar chart of how often each "Customer Group" value occurs.
cg_vc = data_filt["Customer Group"].value_counts()
cg_vc.plot(kind="bar", figsize=(15, 8))

## Multi-label, multi-class

In [None]:
# One-hot encode the four label columns; the remaining columns are kept as-is.
label_cols = ["Type", "Category", "SubCategory", "Customer Group"]
out = pd.get_dummies(data_filt, columns=label_cols)
print(out.shape)
out.head()

In [None]:
# Sanity-check the one-hot label matrix after aggregating it per image id.
tmp = out.groupby("Image Id").sum()
assert not (tmp > 1).any().any()  # no indicator sums above 1 for any image
assert not (tmp < 0).any().any()  # no negative entries
assert not (tmp == 0).all(axis=0).any()  # no column is zero for every image
assert not (tmp == 0).all(axis=1).any()  # no image has all-zero indicators
assert (tmp.sum(axis=1) > 1).all()  # every image has more than one indicator set
assert tmp.reset_index().shape == out.shape  # grouping merged no rows

### Split data

In [None]:
# Hold out 10% of the rows for validation; the seed is fixed for repeatability.
splits = train_test_split(out, test_size=0.1, random_state=42)
out_train, out_val = splits
for split_name, split_df in (("Training", out_train), ("Validation", out_val)):
    print(f"{split_name} data shape: {split_df.shape}")

### Write labels

In [None]:
# Write both splits next to the data as "train_<label_file>" / "val_<label_file>".
for prefix, frame in (("train_", out_train), ("val_", out_val)):
    frame.to_csv(os.path.join(root, prefix + label_file), index=False)

# Train model

In [None]:
# if needed
# !pip install --upgrade torchvision

In [None]:
# Same value as the data root set at the top of the notebook; presumably
# repeated so the training section can be run on its own.
root = "/dbfs/sketches"
# Model directory; concatenated directly onto `root` in the cells below,
# hence the leading slash.
model_dir = "/experiments/base_model"

In [None]:
# Force-reinstall the Sketches wheel stored under the data root; presumably
# it provides the package that train.py depends on - confirm.
!pip install --force-reinstall $root/Sketches-0.0.1-py3-none-any.whl

In [None]:
# Run the training script located in the data root, passing the data root and
# the model directory ($root$model_dir) as arguments.
!python $root/train.py --data_dir $root --model_dir $root$model_dir

In [None]:
# Open TensorBoard on the "runs" sub-directory of the model directory, using port 6009.
%tensorboard --logdir $root$model_dir/runs --port 6009