<a href="https://colab.research.google.com/github/lygitdata/GarmentIQ/blob/main/python_api_demo/classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import `garmentiq`, install dependencies, and download data

In [1]:
import sys

!git clone https://github.com/lygitdata/GarmentIQ.git
!pip install -r /content/GarmentIQ/src/requirements.txt -q

sys.path.insert(0, './GarmentIQ/src')

Cloning into 'GarmentIQ'...
remote: Enumerating objects: 1881, done.[K
remote: Counting objects: 100% (263/263), done.[K
remote: Compressing objects: 100% (237/237), done.[K
remote: Total 1881 (delta 120), reused 83 (delta 25), pack-reused 1618 (from 2)[K
Receiving objects: 100% (1881/1881), 263.47 MiB | 24.03 MiB/s, done.
Resolving deltas: 100% (939/939), done.
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.1/13.1 MB[0m [31m78.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m74.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m71.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m36.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import garmentiq as giq
from garmentiq.classification.model_definition import CNN3, CNN4, tinyViT
from garmentiq.classification.utils import CachedDataset
import torch.optim as optim
import torch.nn as nn

In [3]:
# Download data
!curl -L -o /content/garmentiq-classification-set-nordstrom-and-myntra.zip\
  https://www.kaggle.com/api/v1/datasets/download/lygitdata/garmentiq-classification-set-nordstrom-and-myntra

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 1391M  100 1391M    0     0  34.1M      0  0:00:40  0:00:40 --:--:-- 34.1M


# Function `train_test_split`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/train_test_split.py

In [4]:
# Extract the downloaded archive under /content/data and split it into train
# and test sets (the log below reports both steps and the resulting
# /content/data/train/images and /content/data/test/images folders).
# test_size=0.15 is presumably the fraction held out for testing — confirm in
# the linked source.
# The returned mapping is read later via DATA["train_metadata"],
# DATA["train_images"], DATA["test_metadata"] and DATA["test_images"].
DATA = giq.classification.train_test_split(
    output_dir="/content/data",
    train_zip_dir="/content/garmentiq-classification-set-nordstrom-and-myntra.zip",
    test_size=0.15
)

Extracting:   0%|          | 0/23267 [00:00<?, ?it/s]



Splitting train data into train/test sets...

All filenames in /content/data/train/images match the metadata.


All filenames in /content/data/test/images match the metadata.



# Apply pretrained models

In [5]:
# Download CNN-3, CNN-4, and Tiny ViT pretrained models
# The checkpoints are fetched from the repository's gh-pages branch.
# wget flags: -q silences progress output, -O sets the local destination path.
!mkdir -p /content/pretrained_models

!wget -q -O /content/pretrained_models/cnn_3.pt https://raw.githubusercontent.com/lygitdata/GarmentIQ/refs/heads/gh-pages/application/demo/image-classification/models/cnn_3.pt
!wget -q -O /content/pretrained_models/cnn_4.pt https://raw.githubusercontent.com/lygitdata/GarmentIQ/refs/heads/gh-pages/application/demo/image-classification/models/cnn_4.pt
!wget -q -O /content/pretrained_models/tiny_vit.pt https://raw.githubusercontent.com/lygitdata/GarmentIQ/refs/heads/gh-pages/application/demo/image-classification/models/tiny_vit.pt

## Function `load_data`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/load_data.py

In [6]:
# Load the test split into memory for evaluating the pretrained checkpoints.
# resize_dim=(120, 184) is the same value passed to predict() and to tinyViT's
# img_size in this section; the normalize_mean/normalize_std values are reused
# unchanged by every load_data/predict call in this notebook (presumably
# per-channel dataset statistics — confirm against the library docs).
# The third return value is not needed here and is discarded.
test_images, test_labels, _ = giq.classification.load_data(
    df=DATA["test_metadata"],
    img_dir=DATA["test_images"],
    label_column="garment",
    resize_dim=(120, 184),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

Loading data into memory:   0%|          | 0/3489 [00:00<?, ?it/s]

## Start testing models using the test set

Function `test_pytorch_nn` source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/test_pytorch_nn.py

### Test `CNN-3` on the test set

In [7]:
# Evaluate the downloaded CNN-3 checkpoint on the in-memory test tensors.
# num_classes=9 matches the nine garment categories listed in the
# classification report printed below.
giq.classification.test_pytorch_nn(
    model_path="/content/pretrained_models/cnn_3.pt",
    model_class=CNN3,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    # Keyword arguments forwarded to CachedDataset (per the parameter name).
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    # 3489 test images at batch_size=64 gives the 55 batches shown in the log.
    param={"batch_size": 64},
)

Evaluating:   0%|          | 0/55 [00:00<?, ?it/s]

Test Loss: 0.1678
Test Accuracy: 0.9458
Test F1 Score: 0.9459

Classification Report:
                    precision    recall  f1-score   support

 long sleeve dress       0.93      0.89      0.91       384
   long sleeve top       0.99      0.97      0.98       442
short sleeve dress       0.85      0.93      0.89       382
  short sleeve top       0.97      0.98      0.98       523
            shorts       0.97      0.98      0.98       485
             skirt       0.92      0.91      0.92       281
          trousers       0.98      0.98      0.98       320
              vest       0.94      0.93      0.93       230
        vest dress       0.94      0.90      0.92       442

          accuracy                           0.95      3489
         macro avg       0.94      0.94      0.94      3489
      weighted avg       0.95      0.95      0.95      3489



In [8]:
# Pick one test sample by its index label in the metadata frame. `img_to_test`
# is reused by the CNN-4 and Tiny ViT prediction cells further down.
sample_idx = 963
img_to_test = DATA['test_metadata']['filename'][sample_idx]
# The ground truth lives in the `garment` column; the filename is not a label.
true_label = DATA['test_metadata']['garment'][sample_idx]

# NOTE(review): `classes` follows the first-appearance order of unique();
# confirm that this matches the class-index order the checkpoint was trained with.
pred_label, pred_prob = giq.classification.predict(
    model_path="/content/pretrained_models/cnn_3.pt",
    model_class=CNN3,
    model_args={"num_classes": 9},
    image_path=f"/content/data/test/images/{img_to_test}",
    classes=DATA['test_metadata']['garment'].unique().tolist(),
    resize_dim=(120, 184),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

print(
    "Image: ", img_to_test,
    "\nTrue label: ", true_label,
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)

True label:  long_sleeve_top_3201.jpg 
Predicted label:  long sleeve top 
Predicted Probabilities:  [3.2848849514266476e-05, 0.9988889098167419, 3.692270547617227e-05, 7.212285709101707e-05, 0.00017522269627079368, 6.496862624771893e-05, 0.0001241860882146284, 0.0005721973720937967, 3.258765354985371e-05]


### Test `CNN-4` on the test set

In [9]:
# Evaluate the downloaded CNN-4 checkpoint on the same in-memory test tensors
# and with the same settings as the CNN-3 evaluation above.
giq.classification.test_pytorch_nn(
    model_path="/content/pretrained_models/cnn_4.pt",
    model_class=CNN4,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    # Keyword arguments forwarded to CachedDataset (per the parameter name).
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    param={"batch_size": 64},
)

Evaluating:   0%|          | 0/55 [00:00<?, ?it/s]

Test Loss: 0.1611
Test Accuracy: 0.9533
Test F1 Score: 0.9533

Classification Report:
                    precision    recall  f1-score   support

 long sleeve dress       0.94      0.88      0.91       384
   long sleeve top       0.99      0.98      0.99       442
short sleeve dress       0.87      0.95      0.91       382
  short sleeve top       0.98      0.99      0.99       523
            shorts       0.98      0.99      0.98       485
             skirt       0.93      0.91      0.92       281
          trousers       0.97      0.99      0.98       320
              vest       0.97      0.94      0.96       230
        vest dress       0.94      0.92      0.93       442

          accuracy                           0.95      3489
         macro avg       0.95      0.95      0.95      3489
      weighted avg       0.95      0.95      0.95      3489



In [10]:
# Classify the sample chosen earlier (`img_to_test`) with the CNN-4 checkpoint.
pred_label, pred_prob = giq.classification.predict(
    model_path="/content/pretrained_models/cnn_4.pt",
    model_class=CNN4,
    model_args={"num_classes": 9},
    image_path=f"/content/data/test/images/{img_to_test}",
    classes=DATA['test_metadata']['garment'].unique().tolist(),
    resize_dim=(120, 184),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

# Look up the ground-truth garment for this file instead of printing the
# filename as if it were the label (assumes filenames are unique in the
# metadata; the first match is used).
true_label = DATA['test_metadata'].loc[
    DATA['test_metadata']['filename'] == img_to_test, 'garment'
].iloc[0]

print(
    "Image: ", img_to_test,
    "\nTrue label: ", true_label,
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)

True label:  long_sleeve_top_3201.jpg 
Predicted label:  long sleeve top 
Predicted Probabilities:  [6.615861821046565e-06, 0.99965500831604, 2.618196958792396e-05, 8.517393143847585e-05, 1.625174809305463e-05, 1.419600721419556e-05, 9.749750461196527e-06, 8.316439198097214e-05, 0.00010364743502577767]


### Test `Tiny ViT` on the test set

In [11]:
# Evaluate the downloaded Tiny ViT checkpoint on the in-memory test tensors.
# Unlike the CNNs, tinyViT is also given img_size and patch_size; img_size
# repeats the (120, 184) resize_dim used when the test images were loaded.
giq.classification.test_pytorch_nn(
    model_path="/content/pretrained_models/tiny_vit.pt",
    model_class=tinyViT,
    model_args={"num_classes": 9, "img_size": (120, 184), "patch_size": 6},
    dataset_class=CachedDataset,
    # Keyword arguments forwarded to CachedDataset (per the parameter name).
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    param={"batch_size": 64},
)

model.safetensors:   0%|          | 0.00/22.9M [00:00<?, ?B/s]

Evaluating:   0%|          | 0/55 [00:00<?, ?it/s]

Test Loss: 0.1347
Test Accuracy: 0.9576
Test F1 Score: 0.9576

Classification Report:
                    precision    recall  f1-score   support

 long sleeve dress       0.94      0.92      0.93       384
   long sleeve top       0.98      1.00      0.99       442
short sleeve dress       0.89      0.91      0.90       382
  short sleeve top       0.99      0.99      0.99       523
            shorts       0.99      0.99      0.99       485
             skirt       0.95      0.93      0.94       281
          trousers       0.99      0.98      0.99       320
              vest       0.92      0.95      0.94       230
        vest dress       0.94      0.92      0.93       442

          accuracy                           0.96      3489
         macro avg       0.95      0.95      0.95      3489
      weighted avg       0.96      0.96      0.96      3489



In [12]:
# Classify the sample chosen earlier (`img_to_test`) with the Tiny ViT checkpoint.
pred_label, pred_prob = giq.classification.predict(
    model_path="/content/pretrained_models/tiny_vit.pt",
    model_class=tinyViT,
    model_args={"num_classes": 9, "img_size": (120, 184), "patch_size": 6},
    image_path=f"/content/data/test/images/{img_to_test}",
    classes=DATA['test_metadata']['garment'].unique().tolist(),
    resize_dim=(120, 184),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

# Look up the ground-truth garment for this file instead of printing the
# filename as if it were the label (assumes filenames are unique in the
# metadata; the first match is used).
true_label = DATA['test_metadata'].loc[
    DATA['test_metadata']['filename'] == img_to_test, 'garment'
].iloc[0]

print(
    "Image: ", img_to_test,
    "\nTrue label: ", true_label,
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)

True label:  long_sleeve_top_3201.jpg 
Predicted label:  long sleeve top 
Predicted Probabilities:  [4.32005390393897e-06, 0.9999654293060303, 4.1738971390259394e-07, 1.2229177627887111e-05, 4.470796739042271e-06, 4.943235580867622e-06, 5.885341352040996e-07, 7.537031251558801e-06, 1.1389132481554043e-07]


# Training

## Use pre-defined model structure `CNN3`

To see all pre-defined model structures: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/model_definition.py

In [13]:
# Bare expression: the notebook echoes the imported CNN3 class object as this
# cell's output.
CNN3

In [14]:
# Load the training split into memory for the training runs below.
# resize_dim=(60, 92) is smaller than the (120, 184) used for the pretrained
# checkpoints earlier; the test set is reloaded at this size in the "Testing"
# section so that train and test inputs match.
# The third return value is not needed here and is discarded.
train_images, train_labels, _ = giq.classification.load_data(
    df=DATA["train_metadata"],
    img_dir=DATA["train_images"],
    label_column="garment",
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

Loading data into memory:   0%|          | 0/19777 [00:00<?, ?it/s]

### Function `train_pytorch_nn`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/train_pytorch_nn.py

In [15]:
# Train the pre-defined CNN3 architecture on the in-memory training tensors.
# The log below shows 2 folds of 5 epochs each, with a validation pass per
# epoch, and a final checkpoint at <model_save_dir>/<best_model_name>.
giq.classification.train_pytorch_nn(
    model_class=CNN3,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    # Keyword arguments forwarded to CachedDataset (per the parameter name).
    dataset_args={
        "metadata_df": DATA["train_metadata"],
        "raw_labels": DATA["train_metadata"]["garment"],
        "cached_images": train_images,
        "cached_labels": train_labels,
    },
    param={
        "optimizer_class": optim.AdamW,
        "optimizer_args": {"lr": 0.001, "weight_decay": 1e-4},
        "n_fold": 2,
        "n_epoch": 5,
        # presumably the early-stopping patience in epochs — confirm in the
        # linked train_pytorch_nn source
        "patience": 2,
        "batch_size": 256,
        "model_save_dir": "/content/cnn3_models",
        "best_model_name": "best_model.pt",
    },
)


Fold 1/2


Total Progress:   0%|          | 0/5 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 1 | Epoch 1 | Val Loss: 0.6258 | F1: 0.7861 | Acc: 0.7891


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 1 | Epoch 2 | Val Loss: 0.5848 | F1: 0.7863 | Acc: 0.7919


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 1 | Epoch 3 | Val Loss: 0.3982 | F1: 0.8603 | Acc: 0.8623


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 1 | Epoch 4 | Val Loss: 0.3310 | F1: 0.8874 | Acc: 0.8865


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 1 | Epoch 5 | Val Loss: 0.5671 | F1: 0.8027 | Acc: 0.8066

Fold 2/2


Total Progress:   0%|          | 0/5 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 2 | Epoch 1 | Val Loss: 0.6322 | F1: 0.7585 | Acc: 0.7695


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 2 | Epoch 2 | Val Loss: 0.5661 | F1: 0.8061 | Acc: 0.8095


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 2 | Epoch 3 | Val Loss: 0.3430 | F1: 0.8862 | Acc: 0.8863


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 2 | Epoch 4 | Val Loss: 0.3409 | F1: 0.8850 | Acc: 0.8851


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 2 | Epoch 5 | Val Loss: 0.2969 | F1: 0.9011 | Acc: 0.9014

Training completed. Best model saved at: /content/cnn3_models/best_model.pt


## Use a customized model structure `CNN1`

In [16]:
class CNN1(nn.Module):
    """Small convolutional classifier for 3-channel images.

    ``features`` stacks two Conv -> BatchNorm -> ReLU -> MaxPool -> Dropout
    stages (3 -> 32 -> 64 channels) and ends with adaptive average pooling to
    a fixed 4x4 grid, so ``classifier`` always receives 64 * 4 * 4 values per
    sample whatever the spatial size of the input.

    Args:
        num_classes: Number of output logits produced per sample.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Layer order and attribute names define the state_dict keys
        # (features.0.weight, ...); keep them stable so saved checkpoints load.
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        x = self.features(x)
        # flatten(1) keeps the batch dimension and, unlike view(), also works
        # when the pooled activations are not contiguous (e.g. channels_last).
        x = x.flatten(1)
        return self.classifier(x)

In [17]:
# Train the custom CNN1 architecture defined above on the same in-memory
# training tensors used for CNN3. The log below shows 2 folds of 5 epochs each
# and a final checkpoint at <model_save_dir>/<best_model_name>.
giq.classification.train_pytorch_nn(
    model_class=CNN1,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    # Keyword arguments forwarded to CachedDataset (per the parameter name).
    dataset_args={
        "metadata_df": DATA["train_metadata"],
        "raw_labels": DATA["train_metadata"]["garment"],
        "cached_images": train_images,
        "cached_labels": train_labels,
    },
    param={
        "optimizer_class": optim.AdamW,
        "optimizer_args": {"lr": 0.001, "weight_decay": 1e-4},
        "n_fold": 2,
        "n_epoch": 5,
        # presumably the early-stopping patience in epochs — confirm in the
        # linked train_pytorch_nn source
        "patience": 2,
        "batch_size": 256,
        # NOTE(review): not set in the CNN3 training call; its meaning is not
        # visible in this notebook — check the train_pytorch_nn source.
        "max_workers": 1,
        "model_save_dir": "/content/customized_models",
        "best_model_name": "best_model.pt",
    },
)


Fold 1/2


Total Progress:   0%|          | 0/5 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 1 | Epoch 1 | Val Loss: 1.1905 | F1: 0.5785 | Acc: 0.6016


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 1 | Epoch 2 | Val Loss: 0.8725 | F1: 0.6722 | Acc: 0.6855


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 1 | Epoch 3 | Val Loss: 0.7899 | F1: 0.7024 | Acc: 0.7130


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 1 | Epoch 4 | Val Loss: 0.7143 | F1: 0.7424 | Acc: 0.7470


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 1 | Epoch 5 | Val Loss: 0.6933 | F1: 0.7562 | Acc: 0.7586

Fold 2/2


Total Progress:   0%|          | 0/5 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 2 | Epoch 1 | Val Loss: 1.2084 | F1: 0.5686 | Acc: 0.5977


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 2 | Epoch 2 | Val Loss: 0.9410 | F1: 0.6075 | Acc: 0.6462


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 2 | Epoch 3 | Val Loss: 0.7983 | F1: 0.6972 | Acc: 0.7134


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 2 | Epoch 4 | Val Loss: 0.7168 | F1: 0.7383 | Acc: 0.7439


Training:   0%|          | 0/39 [00:00<?, ?it/s]

Validation:   0%|          | 0/39 [00:00<?, ?it/s]

Fold 2 | Epoch 5 | Val Loss: 0.6958 | F1: 0.7480 | Acc: 0.7520

Training completed. Best model saved at: /content/customized_models/best_model.pt


# Testing

In [18]:
# Reload the test split at resize_dim=(60, 92) so it matches the size the
# models in the "Training" section were trained on.
# NOTE: this rebinds `test_images` / `test_labels`, which previously held the
# (120, 184) tensors used by the pretrained-model cells. Re-running those
# earlier evaluation cells after this one would hand them these smaller inputs.
test_images, test_labels, _ = giq.classification.load_data(
    df=DATA["test_metadata"],
    img_dir=DATA["test_images"],
    label_column="garment",
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

Loading data into memory:   0%|          | 0/3489 [00:00<?, ?it/s]

## Test pre-defined model structure `CNN3`

### Function `test_pytorch_nn`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/test_pytorch_nn.py

In [19]:
# Evaluate the CNN3 checkpoint written by the training run above
# (/content/cnn3_models/best_model.pt) on the 60x92 test tensors.
giq.classification.test_pytorch_nn(
    model_path="/content/cnn3_models/best_model.pt",
    model_class=CNN3,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    # Keyword arguments forwarded to CachedDataset (per the parameter name).
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    # 3489 test images at batch_size=64 gives the 55 batches shown in the log.
    param={"batch_size": 64},
)

Evaluating:   0%|          | 0/55 [00:00<?, ?it/s]

Test Loss: 0.2941
Test Accuracy: 0.8980
Test F1 Score: 0.8973

Classification Report:
                    precision    recall  f1-score   support

 long sleeve dress       0.80      0.91      0.85       384
   long sleeve top       0.93      0.96      0.94       442
short sleeve dress       0.81      0.86      0.84       382
  short sleeve top       0.95      0.93      0.94       523
            shorts       0.89      0.97      0.93       485
             skirt       0.90      0.73      0.81       281
          trousers       0.95      0.98      0.97       320
              vest       0.95      0.84      0.89       230
        vest dress       0.93      0.82      0.87       442

          accuracy                           0.90      3489
         macro avg       0.90      0.89      0.89      3489
      weighted avg       0.90      0.90      0.90      3489



### Function `predict`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/predict.py

In [20]:
# Pick one test sample by its index label in the metadata frame and read both
# its filename and its ground-truth garment from the same row.
sample_idx = 1000
sample_filename = DATA['test_metadata']['filename'][sample_idx]
true_label = DATA['test_metadata']['garment'][sample_idx]

# NOTE(review): `classes` follows the first-appearance order of unique();
# confirm that this matches the class-index order used during training.
pred_label, pred_prob = giq.classification.predict(
    model_path="/content/cnn3_models/best_model.pt",
    model_class=CNN3,
    model_args={"num_classes": 9},
    image_path=f"/content/data/test/images/{sample_filename}",
    classes=DATA['test_metadata']['garment'].unique().tolist(),
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

# Print the actual label; the filename is shown separately for reference.
print(
    "Image: ", sample_filename,
    "\nTrue label: ", true_label,
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)

True label:  trousers_347.jpg 
Predicted label:  trousers 
Predicted Probabilities:  [3.2946604733297136e-06, 3.7289589727151906e-06, 1.1065539183618966e-05, 7.930570063763298e-06, 0.0008685069042257965, 2.6890760636888444e-05, 0.9990187883377075, 4.816617729375139e-05, 1.1670107596728485e-05]


## Test customized model structure `CNN1`

In [21]:
# Evaluate the custom CNN1 checkpoint written by its training run
# (/content/customized_models/best_model.pt) on the 60x92 test tensors.
giq.classification.test_pytorch_nn(
    model_path="/content/customized_models/best_model.pt",
    model_class=CNN1,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    # Keyword arguments forwarded to CachedDataset (per the parameter name).
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    param={"batch_size": 64},
)

Evaluating:   0%|          | 0/55 [00:00<?, ?it/s]

Test Loss: 0.6944
Test Accuracy: 0.7506
Test F1 Score: 0.7501

Classification Report:
                    precision    recall  f1-score   support

 long sleeve dress       0.66      0.54      0.59       384
   long sleeve top       0.90      0.86      0.88       442
short sleeve dress       0.49      0.44      0.46       382
  short sleeve top       0.88      0.89      0.89       523
            shorts       0.90      0.87      0.88       485
             skirt       0.76      0.60      0.67       281
          trousers       0.79      0.89      0.84       320
              vest       0.92      0.73      0.81       230
        vest dress       0.55      0.80      0.65       442

          accuracy                           0.75      3489
         macro avg       0.76      0.74      0.74      3489
      weighted avg       0.76      0.75      0.75      3489



In [22]:
# Pick one test sample by its index label in the metadata frame and read both
# its filename and its ground-truth garment from the same row.
sample_idx = 1000
sample_filename = DATA['test_metadata']['filename'][sample_idx]
true_label = DATA['test_metadata']['garment'][sample_idx]

pred_label, pred_prob = giq.classification.predict(
    model_path="/content/customized_models/best_model.pt",
    model_class=CNN1,
    model_args={"num_classes": 9},
    image_path=f"/content/data/test/images/{sample_filename}",
    classes=DATA['test_metadata']['garment'].unique().tolist(),
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

# Print the actual label; the filename is shown separately for reference.
print(
    "Image: ", sample_filename,
    "\nTrue label: ", true_label,
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)

True label:  trousers_347.jpg 
Predicted label:  trousers 
Predicted Probabilities:  [6.219950591912493e-06, 3.258124479543767e-06, 3.1806232527742395e-06, 1.4495129789793282e-06, 0.04034010320901871, 0.00028508342802524567, 0.9592325687408447, 0.00012734273332171142, 7.457717856595991e-07]
