<a href="https://colab.research.google.com/github/lygitdata/GarmentIQ/blob/main/python_api_demo/classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install dependencies, import `garmentiq`, and download data

In [None]:
import sys

!git clone https://github.com/lygitdata/GarmentIQ.git
!pip install -r /content/GarmentIQ/src/requirements.txt -q

sys.path.insert(0, './GarmentIQ/src')

Cloning into 'GarmentIQ'...
remote: Enumerating objects: 1149, done.[K
remote: Counting objects: 100% (29/29), done.[K
remote: Compressing objects: 100% (19/19), done.[K
remote: Total 1149 (delta 21), reused 9 (delta 9), pack-reused 1120 (from 3)[K
Receiving objects: 100% (1149/1149), 255.19 MiB | 18.18 MiB/s, done.
Resolving deltas: 100% (563/563), done.
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.1/13.1 MB[0m [31m72.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m47.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m108.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import garmentiq as giq
from garmentiq.classification.model_definition import CNN3, CNN4, tinyViT
from garmentiq.classification.utils import CachedDataset

In [None]:
# Download data
!curl -L -o /content/garmentiq-classification-set-nordstrom-and-myntra.zip\
  https://www.kaggle.com/api/v1/datasets/download/lygitdata/garmentiq-classification-set-nordstrom-and-myntra

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 1391M  100 1391M    0     0  95.2M      0  0:00:14  0:00:14 --:--:--  115M


# Function `train_test_split`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/train_test_split.py

In [None]:
# Extract the downloaded archive under /content/data and split it, holding out
# 15% of the images for testing. The returned DATA dict is read below through
# the keys "train_metadata", "train_images", "test_metadata" and "test_images".
DATA = giq.classification.train_test_split(
    train_zip_dir="/content/garmentiq-classification-set-nordstrom-and-myntra.zip",
    output_dir="/content/data",
    test_size=0.15,
)

Extracting: 100%|██████████| 23267/23267 [00:18<00:00, 1225.12it/s]




Splitting train data into train/test sets...

All filenames in /content/data/train/images match the metadata.


All filenames in /content/data/test/images match the metadata.



# Training

## Use pre-defined model structure `CNN3`

All pre-defined model structures (e.g. `CNN3`, `CNN4`, `tinyViT`) are defined in: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/model_definition.py

In [None]:
# Instantiate CNN3 with the same arguments used for training below so the
# notebook displays its layer-by-layer architecture; a bare `CNN3` would only
# display the class object. (Assumes CNN3 is a torch nn.Module, as its use as
# `model_class` in train_pytorch_nn suggests.)
CNN3(num_classes=9)

### Function `load_data`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/load_data.py

In [None]:
# Load all training images into memory once, labelled by the "garment" column.
# The resize and normalisation settings must stay identical to those used in
# the test-set loading and prediction cells. The third return value is unused.
train_images, train_labels, _ = giq.classification.load_data(
    df=DATA["train_metadata"],
    img_dir=DATA["train_images"],
    label_column="garment",
    # preprocessing
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081],
)

Loading data into memory:   0%|          | 0/19777 [00:00<?, ?it/s]

### Function `train_pytorch_nn`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/train_pytorch_nn.py

In [None]:
# Seed PyTorch's global RNG so weight initialisation (and any torch-based
# shuffling inside the trainer) is repeatable across runs. GPU kernels may
# still introduce small run-to-run differences.
torch.manual_seed(42)

# Train CNN3 on the cached training images using `n_fold` folds and early
# stopping after `patience` epochs without improvement (presumably measured on
# the validation fold). The best checkpoint is expected at
# /content/cnn3_models/best_model.pt, which the testing cells below load.
giq.classification.train_pytorch_nn(
    model_class=CNN3,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    dataset_args={
        "metadata_df": DATA["train_metadata"],
        "raw_labels": DATA["train_metadata"]["garment"],
        "cached_images": train_images,
        "cached_labels": train_labels,
    },
    param={
        "optimizer_class": optim.AdamW,
        "optimizer_args": {"lr": 0.001, "weight_decay": 1e-4},
        "n_fold": 2,
        "n_epoch": 5,
        "patience": 2,
        "batch_size": 256,
        "model_save_dir": "/content/cnn3_models",
        "best_model_name": "best_model.pt",
    },
)

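## Use a customized model structure `CNN1`

Any `torch.nn.Module` subclass can be trained the same way: `train_pytorch_nn` receives the class itself as `model_class` together with its constructor arguments in `model_args`, so `CNN1` only needs to accept the `num_classes` argument used below.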

In [None]:
class CNN1(nn.Module):
    """Small two-stage convolutional classifier used as a custom-model example.

    Each stage is Conv(3x3, padding=1) -> BatchNorm -> ReLU -> 2x2 max-pool ->
    Dropout. An adaptive average pool then fixes the feature map at 4x4, so the
    classifier head always receives 64 * 4 * 4 features for a 3-channel input
    that is large enough to survive the two pooling steps.

    Args:
        num_classes: Number of output scores produced by ``forward``.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Submodule positions inside each nn.Sequential define the state_dict
        # keys (features.0, features.1, ..., classifier.3) that saved
        # checkpoints depend on, so the layer order must stay exactly as is.
        self.features = nn.Sequential(
            *self._conv_stage(3, 32, dropout=0.2),
            *self._conv_stage(32, 64, dropout=0.25),
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels, dropout):
        """Return the layers of one conv stage, in the order they are applied."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout(dropout),
        ]

    def forward(self, x):
        """Map a batch ``x`` of 3-channel images to (batch, num_classes) scores."""
        pooled = self.features(x)
        batch_size = pooled.size(0)
        return self.classifier(pooled.view(batch_size, -1))

In [None]:
# Seed PyTorch's global RNG so weight initialisation (and any torch-based
# shuffling inside the trainer) is repeatable across runs. GPU kernels may
# still introduce small run-to-run differences.
torch.manual_seed(42)

# Train the custom CNN1 with the same data and hyper-parameters as CNN3 so the
# two test results below are comparable. The best checkpoint is expected at
# /content/customized_models/best_model.pt, which the testing cells load.
giq.classification.train_pytorch_nn(
    model_class=CNN1,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    dataset_args={
        "metadata_df": DATA["train_metadata"],
        "raw_labels": DATA["train_metadata"]["garment"],
        "cached_images": train_images,
        "cached_labels": train_labels,
    },
    param={
        "optimizer_class": optim.AdamW,
        "optimizer_args": {"lr": 0.001, "weight_decay": 1e-4},
        "n_fold": 2,
        "n_epoch": 5,
        "patience": 2,
        "batch_size": 256,
        "model_save_dir": "/content/customized_models",
        "best_model_name": "best_model.pt",
    },
)

# Testing

In [None]:
# Load the held-out test images into memory with exactly the same resize and
# normalisation settings used for the training images; any mismatch here would
# skew the evaluation. The third return value is unused.
test_images, test_labels, _ = giq.classification.load_data(
    df=DATA["test_metadata"],
    img_dir=DATA["test_images"],
    label_column="garment",
    # preprocessing (must match training)
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081],
)

Loading data into memory:   0%|          | 0/3489 [00:00<?, ?it/s]

## Test pre-defined model structure `CNN3`

### Function `test_pytorch_nn`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/test_pytorch_nn.py

In [None]:
# Evaluate the best CNN3 checkpoint on the held-out test set; the call prints
# the test loss, accuracy, F1 score and a per-class classification report.
giq.classification.test_pytorch_nn(
    model_class=CNN3,
    model_args={"num_classes": 9},
    model_path="/content/cnn3_models/best_model.pt",
    dataset_class=CachedDataset,
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    param={"batch_size": 64},
)

Evaluating:   0%|          | 0/55 [00:00<?, ?it/s]

Test Loss: 0.2930
Test Accuracy: 0.8968
Test F1 Score: 0.8968

Classification Report:
                    precision    recall  f1-score   support

 long sleeve dress       0.87      0.81      0.84       384
   long sleeve top       0.98      0.87      0.92       442
short sleeve dress       0.78      0.87      0.82       382
  short sleeve top       0.87      0.98      0.92       523
            shorts       0.96      0.94      0.95       485
             skirt       0.87      0.84      0.86       281
          trousers       0.95      0.97      0.96       320
              vest       0.93      0.83      0.88       230
        vest dress       0.89      0.88      0.88       442

          accuracy                           0.90      3489
         macro avg       0.90      0.89      0.89      3489
      weighted avg       0.90      0.90      0.90      3489



### Function `predict`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/predict.py

In [None]:
# Classify a single held-out test image with the CNN3 checkpoint and compare
# the prediction with its ground-truth garment label.
example_row = DATA["test_metadata"].loc[1000]

# `classes[i]` must name the model's output index i. The classification report
# lists classes alphabetically, and the probabilities put "trousers" at index 6,
# which suggests training encodes labels in sorted order (confirm against
# CachedDataset). `unique()` alone returns classes in order of first
# appearance, which only matches by accident of the metadata ordering.
pred_label, pred_prob = giq.classification.predict(
    model_path="/content/cnn3_models/best_model.pt",
    model_class=CNN3,
    model_args={"num_classes": 9},
    image_path=f"/content/data/test/images/{example_row['filename']}",
    classes=sorted(DATA["train_metadata"]["garment"].unique()),
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081],
)

# Print the garment label, not the filename, as the true label.
print(
    "True label: ", example_row["garment"],
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)

True label:  trousers_347.jpg 
Predicted label:  trousers 
Predicted Probabilities:  [4.082682607986499e-06, 7.1354088504449464e-06, 6.2376966525334865e-06, 8.394850738113746e-06, 0.0012580862967297435, 5.4195937991607934e-05, 0.9985495209693909, 0.00010102804662892595, 1.1451444152044132e-05]


## Test customized model structure `CNN1`

In [None]:
# Evaluate the best custom CNN1 checkpoint on the same held-out test set, so its
# printed loss, accuracy, F1 and per-class report can be compared with CNN3's.
giq.classification.test_pytorch_nn(
    model_class=CNN1,
    model_args={"num_classes": 9},
    model_path="/content/customized_models/best_model.pt",
    dataset_class=CachedDataset,
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    param={"batch_size": 64},
)

Evaluating:   0%|          | 0/55 [00:00<?, ?it/s]

Test Loss: 0.6809
Test Accuracy: 0.7423
Test F1 Score: 0.7236

Classification Report:
                    precision    recall  f1-score   support

 long sleeve dress       0.64      0.52      0.58       384
   long sleeve top       0.83      0.92      0.87       442
short sleeve dress       0.50      0.13      0.20       382
  short sleeve top       0.88      0.88      0.88       523
            shorts       0.85      0.89      0.87       485
             skirt       0.73      0.69      0.71       281
          trousers       0.92      0.87      0.90       320
              vest       0.92      0.74      0.82       230
        vest dress       0.50      0.90      0.64       442

          accuracy                           0.74      3489
         macro avg       0.75      0.73      0.72      3489
      weighted avg       0.75      0.74      0.72      3489



In [None]:
# Classify the same held-out test image with the custom CNN1 checkpoint and
# compare the prediction with its ground-truth garment label.
example_row = DATA["test_metadata"].loc[1000]

# `classes[i]` must name the model's output index i. The classification report
# lists classes alphabetically, and the probabilities put "trousers" at index 6,
# which suggests training encodes labels in sorted order (confirm against
# CachedDataset). `unique()` alone returns classes in order of first
# appearance, which only matches by accident of the metadata ordering.
pred_label, pred_prob = giq.classification.predict(
    model_path="/content/customized_models/best_model.pt",
    model_class=CNN1,
    model_args={"num_classes": 9},
    image_path=f"/content/data/test/images/{example_row['filename']}",
    classes=sorted(DATA["train_metadata"]["garment"].unique()),
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081],
)

# Print the garment label, not the filename, as the true label.
print(
    "True label: ", example_row["garment"],
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)

True label:  trousers_347.jpg 
Predicted label:  trousers 
Predicted Probabilities:  [2.1105310224811547e-05, 1.2697221791313495e-05, 5.83124619879527e-06, 3.6388353237271076e-06, 0.33395445346832275, 0.0075272442772984505, 0.6571474671363831, 0.0013236853992566466, 4.0134750634024385e-06]
