# AI6126 Project 1 Fashion Attributes Classification Challenges

Lu Cheuk Fung Jeff G2304245F clu014@e.ntu.edu.sg

In [97]:
import timm
import pandas as pd
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import time
import numpy as np
from torchsummary import summary
from transformers import AutoModel, SwinModel

### The box below will control everything

In [98]:
TRAINING = False # True: training
EVALUATE = True # True: calculate the validation loss and accuracy
PRINTING = True # True: print the 'prediction.txt'
MODEL_NAME = "test" # The checkpoint file name if you are training a model
LOAD_CHECKPOINT = 'Swin_large_68_cont_best.pth' # This will be the checkpoint path used if you are evaluation or printing
IMG_DIR = os.path.join("FashionDataset", "FashionDataset","img")# The directory of images of the dataset, default "FashionDataset/FashionDataset/img" 
SPLIT_DIR = os.path.join("FashionDataset", "FashionDataset","split")  # The directory of the labels of the dataset, default "FashionDataset/FashionDataset/split"

In [99]:
BATCH_SIZE = 16
learning_rate = 1e-5 #0.001 1e-4 1e-5(good) 5e-6(best)
momentum = (0.9, 0.999)
wd = 1e-5 # 1e-5(good) 5e-4(best)
EPOCHS = 50
retrain = False
SWIN = True # 1st
XCEPTION = False # 2nd
ONEHOT = True # label = 26 if True
CROSS = False # cross entropy if True, bcewithlogitsloss otherwise
swin_model = ['swin-tiny-patch4-window7-224', 'swin-large-patch4-window7-224', 'swin-base-patch4-window7-224', 'swinv2_base']
swin_hidden_size = [768, 1536, 1024, 1536]
swinmodel = 1
IMAGE_SIZE = 224 if SWIN and swinmodel !=3 else 256 if SWIN else 299
use_weight = False # use bce weight
checkpoint_name = MODEL_NAME
drop_ratio = 0.35 #0.35
two_layer = False
batch_norm = False
MIXUP = True
lr_scheduler = False
early_stopping = 7
mean = [0.7657, 0.7359, 0.7254]
std = [0.2838, 0.2965, 0.3026]

retrain_checkpoint = 'xception_checkpoint_ep7.pth'
# bce_weight = torch.tensor([5.51041659e+00, 4.86166466e+00, 1.39700595e+01, 1.08483410e+01,
#        4.48715555e+01, 1.07641196e+00, 4.61698070e+01, 2.28299408e+00,
#        4.80046398e+00, 9.12045886e-01, 9.24590145e+00, 4.25762351e+00,
#        4.04099971e-01, 1.39693192e+00, 5.05326870e+00, 2.62157757e+02,
#        1.41662638e+00, 1.51290318e+01, 6.12250704e+00, 4.71020887e-01,
#        4.66190432e+01, 1.03166645e+02, 1.04678897e+01, 5.67556735e+00,
#        1.69211463e+01, 2.58811681e-01], dtype=torch.float32)#torch.ones(26) #

bce_weight = torch.tensor([2.32820056, 2.27380113, 2.73221439, 2.62237946, 3.23898725,
       1.61899464, 3.25137419, 1.94552092, 2.26829935, 1.54703282,
       2.55296539, 2.21618339, 1.19350495, 1.73219137, 2.29058852,
       4.00557884, 1.73827145, 2.76682727, 2.37394542, 1.2600563,
       3.25557949, 3.60055544, 2.60687527, 2.34102541, 2.81544591,
       1.        ], dtype=torch.float32)

In [100]:
train_csv = pd.read_csv(os.path.join(SPLIT_DIR, "train.txt"), names=["x", ])
val_csv = pd.read_csv(os.path.join(SPLIT_DIR, "val.txt"), names=["x",])
test_csv = pd.read_csv(os.path.join(SPLIT_DIR, "test.txt"), names=["X",])
train_attr_csv = pd.read_csv(os.path.join(SPLIT_DIR, "train_attr.txt"), delimiter=" ", names=['c1', 'c2', 'c3', 'c4', 'c5', 'c6'])
val_attr_csv = pd.read_csv(os.path.join(SPLIT_DIR, "val_attr.txt"), delimiter=' ', names=['c1', 'c2', 'c3', 'c4', 'c5', 'c6'])

In [101]:
train_csv.iloc[2, 0].split('/')[1]

'00002.jpg'

In [102]:
train_attr_csv.iloc[0].tolist()

[5, 0, 2, 0, 2, 2]

In [103]:
def vec_to_onehot(vec, flatten):
    if flatten:
        onehot = torch.zeros(26, dtype=torch.float32)
        onehot[vec[0]] = 1
        onehot[vec[1]+7] = 1
        onehot[vec[2]+10] = 1
        onehot[vec[3]+13] = 1
        onehot[vec[4]+17] = 1
        onehot[vec[5]+23] = 1
    else:
        onehot = torch.tensor(vec, dtype=torch.float32)
    return onehot

def select_attr(Mat, flatten):
    if flatten:
        vec = torch.zeros(Mat.shape[0], 6)
        vec[:, 0] = torch.argmin(torch.abs(1 - Mat[:, :7]), dim=1)
        vec[:, 1] = torch.argmin(torch.abs(1 - Mat[:, 7:10]), dim=1)
        vec[:, 2] = torch.argmin(torch.abs(1 - Mat[:, 10:13]), dim=1)
        vec[:, 3] = torch.argmin(torch.abs(1 - Mat[:, 13:17]), dim=1)
        vec[:, 4] = torch.argmin(torch.abs(1 - Mat[:, 17:23]), dim=1)
        vec[:, 5] = torch.argmin(torch.abs(1 - Mat[:, 23:]), dim=1)
    else:
        vec = torch.round(Mat)
    return vec
    

In [104]:
class FashionDataset(Dataset):
    def __init__(self, img_dir, X, y, transform=None):
        self.img_dir = img_dir
        self.img_df = X
        self.labels_df = y
        self.transform = transform

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, self.img_df.iloc[idx, 0].split('/')[1])
        image = Image.open(img_name).convert('RGB')
        labels = vec_to_onehot(self.labels_df.iloc[idx].tolist(), ONEHOT) # torch.tensor(self.labels_df.iloc[idx].tolist(), dtype=torch.int32)

        if self.transform:
            image = self.transform(image)

        return image, labels

In [105]:
def get_train_valid_loader(batch_size, seed=1):
    training_transformation = transforms.Compose([
        transforms.Resize((IMAGE_SIZE,IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(30),  # Rotate the image by a random angle between -10 and 10 degrees
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Randomly change brightness, contrast, saturation, and hue
        # transforms.RandomAffine(degrees=20, translate=(0.2, 0.2), scale=(0.8, 1.2), shear=10), # last added
        # transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.5),
        transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0), ratio=(0.9, 1.1)),  # Randomly crop and resize the image
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
        # transforms.Normalize(training_set_mean, training_set_std)
    ])
    valid_transformation = transforms.Compose([
        transforms.Resize((IMAGE_SIZE,IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
        # transforms.Normalize(training_set_mean, training_set_std)
    ])
    test_transformation = transforms.Compose([
        transforms.Resize((IMAGE_SIZE,IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
        # transforms.Normalize(training_set_mean, training_set_std)
    ])
    train_set = FashionDataset(IMG_DIR, train_csv, train_attr_csv, transform=training_transformation)
    validation_set = FashionDataset(IMG_DIR, val_csv, val_attr_csv, transform=valid_transformation)
    test_set = FashionDataset(IMG_DIR, test_csv, val_attr_csv, transform=test_transformation)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader, test_loader

In [106]:
train_loader, valid_loader, test_loader = get_train_valid_loader(BATCH_SIZE)

In [107]:
if retrain:
    print("retrain")
    model = timm.create_model('xception41', pretrained=False, num_classes=(26 if ONEHOT else 6))
    model.cuda()
    if ONEHOT:
        if not CROSS:
            criterion = torch.nn.BCEWithLogitsLoss(weight=(bce_weight if use_weight else None))#weight=bce_weight.cuda()
        else:
            print("Cross Entropy")
            criterion = torch.nn.CrossEntropyLoss()
    else:
        criterion = nn.MSELoss().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=momentum, weight_decay=wd)

    # Load model checkpoint
    checkpoint = torch.load(retrain_checkpoint)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # loss = checkpoint['loss']
    print("loaded pretrained weight")
else:
    if not SWIN:
        # model = timm.create_model("hf_hub:timm/xception41.tf_in1k", pretrained=True, num_classes=(26 if ONEHOT else 6))
        if XCEPTION:
            pretrained_cfg_overlay = {'file' : r"./model/xception71.tf_in1k/pytorch_model.bin"}
            model = timm.models.create_model('xception71.tf_in1k', pretrained=True, pretrained_cfg_overlay=pretrained_cfg_overlay, num_classes=(26 if ONEHOT else 6))
        else:
            pretrained_cfg_overlay = {'file' : r"./model/inceptionv3/pytorch_model.bin"}
            model = timm.models.create_model('inception_v3.tf_in1k', pretrained=True, pretrained_cfg_overlay=pretrained_cfg_overlay, num_classes=(26 if ONEHOT else 6))
        # model = timm.create_model("xception41.tf_in1k", pretrained=True, num_classes=(26 if ONEHOT else 6))
    else:
        swin = SwinModel.from_pretrained(swin_model[swinmodel], local_files_only=True)
    print("loaded pretrained model")

loaded pretrained model


In [108]:
class MultilabelSwin(nn.Module):
    def __init__(self, swin):
        super(MultilabelSwin, self).__init__()
        # self.num_classes = 26
        self.transformer = swin
        self.dropout = nn.Dropout(p=drop_ratio)
        hidden_channel = 300
        if batch_norm:
            self.bn2 = nn.BatchNorm1d(swin_hidden_size[swinmodel])
        if two_layer:
            self.fc2 = nn.Linear(swin_hidden_size[swinmodel], hidden_channel)
            self.bn1 = nn.BatchNorm1d(hidden_channel)
            self.relu = nn.ReLU()
            self.fc1 = nn.Linear(hidden_channel, 26 if ONEHOT else 6)
            self.dropout2 = nn.Dropout(p=drop_ratio)
        else:
            self.fc1 = nn.Linear(swin_hidden_size[swinmodel], 26 if ONEHOT else 6)

    def forward(self, x):
        x = self.transformer(x)
        x = x.pooler_output
        if batch_norm:
            x = self.bn2(x)
        x = self.dropout(x)
        if two_layer:
            x = self.fc2(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.dropout2(x)
        x = self.fc1(x)
        return x

In [109]:
if SWIN:
    model = MultilabelSwin(swin)
    
# Load model checkpoint
if not TRAINING:
    checkpoint = torch.load(LOAD_CHECKPOINT)
    model.load_state_dict(checkpoint['model_state_dict'])

In [110]:
if not retrain:
    print("loaded to cuda")
    model.cuda()
    if ONEHOT:
        if not CROSS:
            print("BCEWithLogitsLoss")
            criterion = torch.nn.BCEWithLogitsLoss(weight=(bce_weight if use_weight else None)).cuda()
            val_criterion = torch.nn.BCEWithLogitsLoss().cuda()
        else:
            print("cross entropy")
            criterion = nn.CrossEntropyLoss().cuda()
            val_criterion = nn.CrossEntropyLoss().cuda()
    else:
        print("MSELoss")
        criterion = nn.MSELoss().cuda()
        val_criterion = nn.MSELoss().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=momentum, weight_decay=wd)
else:
    print("nothing done")

loaded to cuda
BCEWithLogitsLoss


In [111]:
def multi_label_accuracy(y_pred, y_true, threshold=0.5):
    """
    Calculate the multi-label accuracy given the predicted and true labels.
    
    Args:
    - y_pred (torch.Tensor): Predicted labels (probabilities) from the model. Shape (batch_size, num_classes).
    - y_true (torch.Tensor): True labels (binary) for each sample. Shape (batch_size, num_classes).
    - threshold (float): Threshold value for determining the predicted labels.
    
    Returns:
    - float: Multi-label accuracy.
    """
    # Apply threshold to predicted labels
    y_pred_labels = (y_pred > threshold).float()
    
    # Calculate accuracy for each sample
    sample_accuracy = torch.sum(y_pred_labels == y_true, dim=1).float()/y_true.shape[1]
    # sample_accuracy = correct_samples == y_true.shape[1]
    
    # Calculate overall accuracy
    accuracy = torch.mean(sample_accuracy.float(), dim=0)
    
    return accuracy.item()

def multi_category_accuracy(y_pred, y_true):
    Y_pred = select_attr(y_pred, ONEHOT)
    Y_true = select_attr(y_true, ONEHOT)
    return torch.mean(torch.sum(Y_pred == Y_true, dim=1).float()/Y_true.shape[1], dim=0).item()

# evaluation
def compute_avg_class_acc(y_pred, y_true):
    # if ONEHOT:
    pred_labels = select_attr(y_pred, ONEHOT)
    gt_labels = select_attr(y_true, ONEHOT)

    # print(pred_labels[1], gt_labels[1])
    # else: 
    #     pred_labels = y_pred
    #     gt_labels = y_true
    
    num_attr = 6
    num_classes = [7, 3, 3, 4, 6, 3]  # number of classes in each attribute
    
    per_class_acc = []
    for attr_idx in range(num_attr):
        for idx in range(num_classes[attr_idx]):
            target = gt_labels[:, attr_idx]#.cpu().detach().numpy()
            pred = pred_labels[:, attr_idx]#.cpu().detach().numpy()
            correct = torch.sum((target == pred) * (target == idx))
            total = torch.sum(target == idx)
            # print(correct, total)
            per_class_acc.append(float(correct) / float(total) if float(total) != 0 else float(1))  # if float(correct) != float(0) else float(0)

    return sum(per_class_acc) / len(per_class_acc)

def mixup_data(x, y, alpha=1.0):
    # Generate a random weight for mixing the data
    # lam = torch.rand(x.size(0), 1).to(x.device)
    lam = torch.rand(1).to(x.device)
    lam = torch.max(lam, 1 - lam)
    
    # Generate a mixed image
    mixed_x = lam * x + (1 - lam) * x.flip(dims=(0,))
    
    # Generate mixed labels
    y_a, y_b = y, y.flip(dims=(0,))
    mixed_y = lam * y_a + (1 - lam) * y_b
    
    return mixed_x, mixed_y, y_a, y_b, lam
    
    

In [112]:
stat_training_loss = []
stat_val_loss = []
stat_training_acc = []
stat_val_acc = []
best_val_acc = 0
best_val_loss = float('inf')
partition = [0, 7, 10, 13, 17, 23, 26]
num_attr = 6
best_epoch = 0

In [None]:
if TRAINING:

    if lr_scheduler:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS)

    start = time.time()
    for epoch in range(EPOCHS):
            training_loss = 0
            training_acc = 0
            batch_num = 0
            val_loss = 0
            val_acc = 0
            val_batch_num = 0
            # training
            model.train()
            for imgs, labels in train_loader:

                imgs = imgs.cuda()
                labels = labels.cuda()

                if MIXUP:
                    imgs, labels, y_a, y_b, lam = mixup_data(imgs, labels)
            
                batch_size = imgs.shape[0]
                optimizer.zero_grad()
                logits = model.forward(imgs)
                loss = None
                if CROSS:
                    for i in range(num_attr):
                        if loss == None:
                            loss = criterion(logits[:, partition[i] : partition[i+1]], labels[:, partition[i] : partition[i+1]])
                        else:
                            loss += criterion(logits[:, partition[i] : partition[i+1]], labels[:, partition[i] : partition[i+1]])
                elif MIXUP:
                    loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
                else:
                    loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()
                
                training_acc += compute_avg_class_acc(logits, labels)#multi_label_accuracy(logits, labels) #
                
                training_loss += loss.item()
                batch_num += 1
                
            # validation
            model.eval()
            for val_imgs, val_labels in valid_loader:
                
                val_imgs = val_imgs.cuda()
                val_labels = val_labels.cuda()
                
                batch_size = val_imgs.shape[0]
                val_logits = model.forward(val_imgs)
                loss = val_criterion(val_logits, val_labels)
                val_acc += compute_avg_class_acc(val_logits.cuda(), val_labels)#multi_label_accuracy(val_logits.cuda(), val_labels) #
                val_loss += loss.item()
                val_batch_num += 1
            # assert val_batch_num == 10000
            # update stats
            stat_training_loss.append(training_loss/batch_num)
            stat_val_loss.append(val_loss/val_batch_num)
            stat_training_acc.append(training_acc/batch_num)
            stat_val_acc.append(val_acc/val_batch_num)
            # print
            print(f"Epoch {(epoch+1):d}/{EPOCHS:d}.. Train loss: {(training_loss/batch_num):.4f}.. Train acc: {(training_acc/batch_num):.4f}.. Val loss: {(val_loss/val_batch_num):.4f}.. Val acc: {(val_acc/val_batch_num):.4f}")
            # lr scheduler
            if lr_scheduler:
                scheduler.step()

            if epoch > 0 and best_val_loss > val_loss/val_batch_num : # val_acc/val_batch_num > best_val_acc
                best_val_acc = val_acc/val_batch_num
                best_val_loss = val_loss/val_batch_num
                best_epoch = epoch
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss
                }, f'{checkpoint_name}_best.pth')
            if epoch > 0 and epoch - best_epoch >= early_stopping:
                break
        # save checkpoint every 8 step
            # if epoch%8 == 0 and epoch > 0:
            #     torch.save({
            #         'model_state_dict': model.state_dict(),
            #         'optimizer_state_dict': optimizer.state_dict(),
            #         'loss': loss
            #     }, f'{checkpoint_name}_ep{epoch}.pth')

    print('Time: ', time.time() - start)
    np.savez(f'{checkpoint_name}.npz', stat_training_loss=stat_training_loss, stat_val_loss=stat_val_loss, stat_training_acc=stat_training_acc, stat_val_acc=stat_val_acc)


else:
    print("Not Training...")

## Evaluating Validation Accuracy

In [None]:
if EVALUATE:
    loader = valid_loader
    MIXUP = False

    model.eval()
    with torch.no_grad():
        category_acc = 0
        acc = 0
        loss = 0
        batch_num = 0
        for imgs, labels in loader:
            
            imgs = imgs.cuda()
            labels = labels.cuda()

            if MIXUP:
                imgs, labels, y_a, y_b, lam = mixup_data(imgs, labels)

            logits = model.forward(imgs)
            if MIXUP:
                loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
            else:
                loss = criterion(logits, labels)
            # loss = criterion(logits, labels)
            acc += compute_avg_class_acc(logits.cuda(), labels)
            loss += loss.item()
            batch_num += 1
        print(f"Accuracy: {acc/batch_num} .. Loss: {loss/batch_num}")  

Accuracy: 0.8176965933605127 .. Loss: 0.003741858061403036


In [None]:
# # Load model checkpoint
# checkpoint = torch.load(f'{checkpoint_name}_best.pth')
# model.load_state_dict(checkpoint['model_state_dict'])
# # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# # loss = checkpoint['loss']
# epoch = checkpoint['epoch']
# print(epoch)

if PRINTING:
    model.eval()
    with torch.no_grad():
        def print_to_csv(y_pred):
            res = select_attr(y_pred, ONEHOT)
            with open('prediction.txt', 'a') as file:
                file.write('\n'.join([' '.join(map(str, row)) for row in res.to(torch.int).numpy()])) # np.array2string(res.numpy())
                file.write("\n")

        model.eval()
        category_acc = 0
        test_acc = 0
        test_loss = 0
        batch_count = 0
        for test_imgs, test_labels in test_loader:
            
            test_imgs = test_imgs.cuda()
            test_labels = test_labels.cuda()
            
            # batch_size = val_imgs.shape[0]
            test_logits = model.forward(test_imgs)
            # loss = criterion(val_logits, val_labels)
            print_to_csv(test_logits)
