# Import `garmentiq`, other dependencies, and download data

In [None]:
import garmentiq as giq
from garmentiq.classification.model_definition import CNN3, CNN4, tinyViT
from garmentiq.classification.utils import CachedDataset
import torch.optim as optim
import torch.nn as nn

In [None]:
# Download data
!curl -L -o garmentiq-classification-set-nordstrom-and-myntra.zip \
  https://www.kaggle.com/api/v1/datasets/download/lygitdata/garmentiq-classification-set-nordstrom-and-myntra

# Function `train_test_split`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/train_test_split.py

In [None]:
# Split the downloaded archive into train and test sets (test_size=0.15) under the
# "data" output directory; see the linked source for the exact split semantics.
# The returned mapping is indexed later in this notebook with the keys
# "train_metadata", "train_images", "test_metadata" and "test_images".
DATA = giq.classification.train_test_split(
    output_dir="data",
    train_zip_dir="garmentiq-classification-set-nordstrom-and-myntra.zip",
    test_size=0.15,
    verbose=True
)

# Apply pretrained models

In [None]:
# Download CNN-3, CNN-4, and Tiny ViT pretrained models into pretrained_models/
# (created first if it does not exist). -q silences wget's progress output and -O
# sets the local file name; these paths are the `model_path` values used by the
# test and predict cells of this section.
!mkdir -p pretrained_models

!wget -q -O pretrained_models/cnn_3.pt \
    https://raw.githubusercontent.com/lygitdata/GarmentIQ/refs/heads/gh-pages/application/demo/image-classification/models/cnn_3.pt
!wget -q -O pretrained_models/cnn_4.pt \
    https://raw.githubusercontent.com/lygitdata/GarmentIQ/refs/heads/gh-pages/application/demo/image-classification/models/cnn_4.pt
!wget -q -O pretrained_models/tiny_vit.pt \
    https://raw.githubusercontent.com/lygitdata/GarmentIQ/refs/heads/gh-pages/application/demo/image-classification/models/tiny_vit.pt

## Function `load_data`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/load_data.py

In [None]:
# Load the test split once so the three pretrained models below are evaluated on
# the same `test_images` / `test_labels`. The third return value is not used here.
# resize_dim, normalize_mean and normalize_std are the same values passed to
# `predict` in this section, and (120, 184) equals the `img_size` given to tinyViT.
# NOTE: the "Testing" section further down rebinds `test_images` / `test_labels`
# with resize_dim=(60, 92); re-run this cell before re-running the pretrained tests.
test_images, test_labels, _ = giq.classification.load_data(
    df=DATA["test_metadata"],
    img_dir=DATA["test_images"],
    label_column="garment",
    resize_dim=(120, 184),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

## Start testing models using the test set

Function `test_pytorch_nn` source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/test_pytorch_nn.py

### Test `CNN-3` on the test set

In [None]:
# Evaluate the downloaded CNN-3 weights on the 120x184 test tensors loaded above.
# `model_args` / `dataset_args` are presumably forwarded to CNN3 and CachedDataset
# respectively - see the linked `test_pytorch_nn` source to confirm.
giq.classification.test_pytorch_nn(
    model_path="pretrained_models/cnn_3.pt",
    model_class=CNN3,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    param={"batch_size": 64},
)

In [None]:
# Pick one test image and read its ground-truth garment class from the same
# metadata row, so the printout compares the prediction against the real label
# rather than against the image file name.
sample_index = 963
img_to_test = DATA['test_metadata']['filename'][sample_index]
true_label = DATA['test_metadata']['garment'][sample_index]

# resize_dim / normalize_mean / normalize_std repeat the values given to
# `load_data` for the 120x184 test tensors.
# NOTE(review): `unique()` lists classes in order of first appearance; confirm that
# this order matches the class indices the pretrained weights were trained with.
pred_label, pred_prob = giq.classification.predict(
    model_path="pretrained_models/cnn_3.pt",
    model_class=CNN3,
    model_args={"num_classes": 9},
    image_path=f"data/test/images/{img_to_test}",
    classes=DATA['test_metadata']['garment'].unique().tolist(),
    resize_dim=(120, 184),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

print(
    "Image file: ", img_to_test,
    "\nTrue label: ", true_label,
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)

### Test `CNN-4` on the test set

In [None]:
# Evaluate the downloaded CNN-4 weights on the same 120x184 test tensors, with the
# same dataset arguments and batch size as the CNN-3 evaluation.
giq.classification.test_pytorch_nn(
    model_path="pretrained_models/cnn_4.pt",
    model_class=CNN4,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    param={"batch_size": 64},
)

In [None]:
# Ground-truth class of the image chosen in the CNN-3 cell (`img_to_test`), looked
# up by file name so "True label" shows the garment class, not the file name.
true_label = DATA['test_metadata'].loc[
    DATA['test_metadata']['filename'] == img_to_test, 'garment'
].iloc[0]

# Same image and preprocessing as the CNN-3 prediction; only the weights and the
# model class differ.
# NOTE(review): `unique()` lists classes in order of first appearance; confirm that
# this order matches the class indices the pretrained weights were trained with.
pred_label, pred_prob = giq.classification.predict(
    model_path="pretrained_models/cnn_4.pt",
    model_class=CNN4,
    model_args={"num_classes": 9},
    image_path=f"data/test/images/{img_to_test}",
    classes=DATA['test_metadata']['garment'].unique().tolist(),
    resize_dim=(120, 184),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

print(
    "Image file: ", img_to_test,
    "\nTrue label: ", true_label,
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)

### Test `Tiny ViT` on the test set

In [None]:
# Evaluate the downloaded Tiny ViT weights on the same test tensors. Unlike the
# CNNs, tinyViT is also given `img_size` (equal to the resize_dim used when the
# test images were loaded) and `patch_size`.
giq.classification.test_pytorch_nn(
    model_path="pretrained_models/tiny_vit.pt",
    model_class=tinyViT,
    model_args={"num_classes": 9, "img_size": (120, 184), "patch_size": 6},
    dataset_class=CachedDataset,
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    param={"batch_size": 64},
)

In [None]:
# Ground-truth class of the image chosen in the CNN-3 cell (`img_to_test`), looked
# up by file name so "True label" shows the garment class, not the file name.
true_label = DATA['test_metadata'].loc[
    DATA['test_metadata']['filename'] == img_to_test, 'garment'
].iloc[0]

# Same image and preprocessing as the CNN predictions; `model_args` repeats the
# img_size / patch_size used for the Tiny ViT evaluation.
# NOTE(review): `unique()` lists classes in order of first appearance; confirm that
# this order matches the class indices the pretrained weights were trained with.
pred_label, pred_prob = giq.classification.predict(
    model_path="pretrained_models/tiny_vit.pt",
    model_class=tinyViT,
    model_args={"num_classes": 9, "img_size": (120, 184), "patch_size": 6},
    image_path=f"data/test/images/{img_to_test}",
    classes=DATA['test_metadata']['garment'].unique().tolist(),
    resize_dim=(120, 184),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

print(
    "Image file: ", img_to_test,
    "\nTrue label: ", true_label,
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)

# Training

## Use pre-defined model structure `CNN3`

To see all pre-defined model structures: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/model_definition.py

In [None]:
# Bare expression: the notebook displays the imported CNN3 class object as the
# cell output, confirming which pre-defined architecture is trained below.
CNN3

In [None]:
# Load the training split for the training cells below. The third return value is
# not used. resize_dim is (60, 92) here, half the (120, 184) used in the
# pretrained-model section - presumably to keep this demo run short (confirm).
# The normalization statistics are the same as everywhere else in this notebook.
train_images, train_labels, _ = giq.classification.load_data(
    df=DATA["train_metadata"],
    img_dir=DATA["train_images"],
    label_column="garment",
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

### Function `train_pytorch_nn`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/train_pytorch_nn.py

In [None]:
# Train the pre-defined CNN3 on the 60x92 training tensors with AdamW.
# n_fold / n_epoch / patience are presumably the number of cross-validation folds,
# epochs per fold and early-stopping patience - see the linked source to confirm.
# The best weights are expected at cnn3_models/best_model.pt, which is the
# `model_path` the "Testing" section loads.
giq.classification.train_pytorch_nn(
    model_class=CNN3,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    dataset_args={
        "metadata_df": DATA["train_metadata"],
        "raw_labels": DATA["train_metadata"]["garment"],
        "cached_images": train_images,
        "cached_labels": train_labels,
    },
    param={
        "optimizer_class": optim.AdamW,
        "optimizer_args": {"lr": 0.001, "weight_decay": 1e-4},
        "n_fold": 2,
        "n_epoch": 5,
        "patience": 2,
        "batch_size": 256,
        "model_save_dir": "cnn3_models",
        "best_model_name": "best_model.pt",
    },
)

## Use a customized model structure `CNN1`

In [None]:
class CNN1(nn.Module):
    """Small two-stage convolutional classifier for 3-channel image batches.

    Two Conv -> BatchNorm -> ReLU -> MaxPool -> Dropout stages (32, then 64
    channels) are followed by an adaptive average pool that fixes the feature
    map at 4x4, so the classifier head always receives 64 * 4 * 4 features.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Layers are instantiated in the order they are applied, which keeps the
        # `features.N` / `classifier.N` indices (and state_dict keys) stable.
        stage_one = [
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.2),
        ]
        stage_two = [
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),
        ]
        self.features = nn.Sequential(
            *stage_one,
            *stage_two,
            nn.AdaptiveAvgPool2d((4, 4)),
        )

        head = [
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the classifier output, one row of `num_classes` scores per sample."""
        feature_maps = self.features(x)
        # Collapse channels and the 4x4 grid into one vector per sample.
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flat)

In [None]:
# Train the custom CNN1 defined above with the same data, optimizer and schedule
# settings as the CNN3 run. Differences from that call: "max_workers" is set
# explicitly here (the CNN3 call leaves it unset), and the best weights go to
# customized_models/best_model.pt, which the "Testing" section loads.
giq.classification.train_pytorch_nn(
    model_class=CNN1,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    dataset_args={
        "metadata_df": DATA["train_metadata"],
        "raw_labels": DATA["train_metadata"]["garment"],
        "cached_images": train_images,
        "cached_labels": train_labels,
    },
    param={
        "optimizer_class": optim.AdamW,
        "optimizer_args": {"lr": 0.001, "weight_decay": 1e-4},
        "n_fold": 2,
        "n_epoch": 5,
        "patience": 2,
        "batch_size": 256,
        "max_workers": 1,
        "model_save_dir": "customized_models",
        "best_model_name": "best_model.pt",
    },
)

# Testing

In [None]:
# Reload the test split at resize_dim=(60, 92), the size used for the training
# tensors, so the models trained in this notebook are tested on matching inputs.
# NOTE: this rebinds `test_images` / `test_labels`, replacing the 120x184 tensors
# loaded for the pretrained models; the pretrained test cells should not be re-run
# after this cell without first re-running their own `load_data` cell.
test_images, test_labels, _ = giq.classification.load_data(
    df=DATA["test_metadata"],
    img_dir=DATA["test_images"],
    label_column="garment",
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

## Test pre-defined model structure `CNN3`

### Function `test_pytorch_nn`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/test_pytorch_nn.py

In [None]:
# Evaluate the CNN3 weights written by the training run (cnn3_models/best_model.pt)
# on the 60x92 test tensors loaded at the start of this "Testing" section.
giq.classification.test_pytorch_nn(
    model_path="cnn3_models/best_model.pt",
    model_class=CNN3,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    param={"batch_size": 64},
)

### Function `predict`

Source code: https://github.com/lygitdata/GarmentIQ/blob/main/src/garmentiq/classification/predict.py

In [None]:
# Pick one test image and read its ground-truth garment class from the same
# metadata row, so "True label" shows the class rather than the file name.
sample_index = 1000
sample_file = DATA['test_metadata']['filename'][sample_index]
true_label = DATA['test_metadata']['garment'][sample_index]

# resize_dim is (60, 92) to match the size the training images were loaded at.
# NOTE(review): `unique()` lists classes in order of first appearance; confirm that
# this order matches the class indices used when the model was trained.
pred_label, pred_prob = giq.classification.predict(
    model_path="cnn3_models/best_model.pt",
    model_class=CNN3,
    model_args={"num_classes": 9},
    image_path=f"data/test/images/{sample_file}",
    classes=DATA['test_metadata']['garment'].unique().tolist(),
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

print(
    "Image file: ", sample_file,
    "\nTrue label: ", true_label,
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)

## Test customized model structure `CNN1`

In [None]:
# Evaluate the custom CNN1 weights written by its training run
# (customized_models/best_model.pt) on the same 60x92 test tensors.
giq.classification.test_pytorch_nn(
    model_path="customized_models/best_model.pt",
    model_class=CNN1,
    model_args={"num_classes": 9},
    dataset_class=CachedDataset,
    dataset_args={
        "raw_labels": DATA["test_metadata"]["garment"],
        "cached_images": test_images,
        "cached_labels": test_labels,
    },
    param={"batch_size": 64},
)

In [None]:
# Pick one test image and read its ground-truth garment class from the same
# metadata row, so "True label" shows the class rather than the file name.
sample_index = 1000
sample_file = DATA['test_metadata']['filename'][sample_index]
true_label = DATA['test_metadata']['garment'][sample_index]

# resize_dim is (60, 92) to match the size the training images were loaded at.
# NOTE(review): `unique()` lists classes in order of first appearance; confirm that
# this order matches the class indices used when the model was trained.
pred_label, pred_prob = giq.classification.predict(
    model_path="customized_models/best_model.pt",
    model_class=CNN1,
    model_args={"num_classes": 9},
    image_path=f"data/test/images/{sample_file}",
    classes=DATA['test_metadata']['garment'].unique().tolist(),
    resize_dim=(60, 92),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081]
)

print(
    "Image file: ", sample_file,
    "\nTrue label: ", true_label,
    "\nPredicted label: ", pred_label,
    "\nPredicted Probabilities: ", pred_prob
)