In [1]:
import torch
from torch.utils.data import DataLoader
from models.resnet_clothing_model import ClothingClassifier
from data.dataset import GarmentDataset
from utils.loss import compute_loss
from utils.metrics import accuracy
from data.transformation import CustomResNetTransform
from tqdm import tqdm
import matplotlib.pyplot as plt 
import pandas as pd
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
"""
DATA SET UP
"""
df = pd.read_csv('dataset/list_category_img.txt', delim_whitespace=True, skiprows=1)
selected_labels = [3, 6, 11, 16, 17, 18, 19, 26, 32, 33, 41]
selected_names = ["Blouse", "Cardigan", "Jacket", "Sweater", "Tank", "Tee", "Top", "Jeans", "Shorts", "Skirts", "Dress"]
df = df[df['category_label'].isin(selected_labels)]
df = df.reset_index(drop=True)


label_mapping = {original: new for new, original in enumerate(selected_labels)}
df['category_label'] = df['category_label'].map(label_mapping)

train_val_df, test_df = train_test_split(
    df,
    test_size=0.2,
    random_state=1331,
    stratify=df['category_label']
)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=0.25,
    random_state=1331,
    stratify=train_val_df['category_label']
)

balanced_train_df = train_df.groupby('category_label', group_keys=False).apply(
    lambda x: x.sample(n=1800, random_state=1331)
).reset_index(drop=True)

train_dataset = GarmentDataset(train_df, transform=CustomResNetTransform())
balanced_dataset = GarmentDataset(balanced_train_df, transform=CustomResNetTransform())
val_dataset = GarmentDataset(val_df, transform=CustomResNetTransform())
test_dataset = GarmentDataset(test_df, transform=CustomResNetTransform())

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)
balanced_loader = DataLoader(balanced_dataset, batch_size=64, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True, num_workers=8)

# Get class weights for loss function
class_sample_counts = train_df['category_label'].value_counts().sort_index().tolist()
class_counts_tensor = torch.tensor(class_sample_counts, dtype=torch.float)
class_weights = 1.0 / class_counts_tensor
class_weights = class_weights / class_weights.mean()

  df = pd.read_csv('dataset/list_category_img.txt', delim_whitespace=True, skiprows=1)
  balanced_train_df = train_df.groupby('category_label', group_keys=False).apply(


In [9]:
"""
TRAIN CODE
"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

NUM_CLASSES = 11
FROZEN_LAYERS = 60
LEARNING_RATE = 0.001
NUM_EPOCHS = 22

model = ClothingClassifier(num_classes=NUM_CLASSES, num_frozen_resnet_layers=FROZEN_LAYERS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3, verbose=True
)


epoch_losses = []
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    total_correct = 0
    total_samples = 0

    # 3) Training loop
    with tqdm(train_loader, desc=f"Epoch Train {epoch+1}/{NUM_EPOCHS}", unit="batch") as pbar:
        for images, category_id in pbar:
            images = images.to(device)
            category_id = category_id.to(device)


            optimizer.zero_grad()
            preds = model(images)
            loss = compute_loss(preds, {'category_id': category_id}, class_weights=class_weights.to(device))

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            batch_correct = accuracy(preds, category_id)
            total_correct += batch_correct
            total_samples += images.size(0)

            pbar.set_postfix(loss=loss.item())

    # 4) Check against a validator to adjust learning rate
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        with tqdm(val_loader, desc=f"Epoch Val {epoch+1}/{NUM_EPOCHS}", unit="batch") as pbar:
            for images, category_id in pbar:
                images = images.to(device)
                category_id = category_id.to(device)
                preds = model(images)
                loss = compute_loss(preds, {'category_id': category_id})
                val_loss += loss.item()
    val_loss /= len(val_loader)
    scheduler.step(val_loss)
    epoch_loss = running_loss / len(train_loader)
    epoch_losses.append(epoch_loss)
    epoch_accuracy = (total_correct / total_samples) * 100
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}: Accuracy = {epoch_accuracy:.2f}%, Training Loss = {epoch_loss:.4f}, Validation Loss = {val_loss:.4f}")
print("Training complete. Saving model...")
torch.save(model.state_dict(), 'clothing_classifier.pth')
print("Model saved.")


plt.subplot(1, 1, 1)
plt.plot(epoch_losses, marker='o', color='red', label='Epoch Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Across Epochs')
plt.legend()

plt.tight_layout()
plt.show()

Using device: cpu


Epoch Train 1/22:   0%|          | 1/2227 [00:17<10:45:58, 17.41s/batch, loss=2.43]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x125e15300>
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/

KeyboardInterrupt: 