In [1]:
import torchvision
from torchvision import transforms
import torch
import os
from PIL import Image
from torch.utils.data import Dataset
import torch

os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

class PETA_Dataset(Dataset):
    """Binary (male / not-male) image dataset built from PETA-style folders.

    Directory layout read by ``__init__``::

        root_dir/<subset>/archive/Label.txt
        root_dir/<subset>/archive/<id>_<anything>.<ext>

    Every non-empty line of ``Label.txt`` is whitespace-separated: the first
    token is an ID and the remaining tokens are attribute tags. An ID is
    labelled ``1`` when the tag ``personalMale`` is present and ``0``
    otherwise. An image belongs to the ID given by the part of its file name
    before the first underscore.

    Args:
        root_dir: Folder that contains one sub-folder per subset.
        transform: Optional callable applied to the RGB ``PIL.Image`` in
            ``__getitem__``.

    Attributes:
        image_paths: List of ``(path, label)`` tuples, one per usable image.
        labels: Mapping ``"<subset>_<id>" -> label`` for every ID that was
            found in a label file.
    """

    # Image file extensions (lower-case) that are picked up from a subset folder.
    VALID_EXTENSIONS = frozenset({'.bmp', '.jpg', '.jpeg', '.png'})

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = {}

        # Sorted listings make the sample order independent of the file
        # system, so index-based splits of this dataset are repeatable.
        for subset in sorted(os.listdir(root_dir)):
            subset_path = os.path.join(root_dir, subset, 'archive')
            label_file = os.path.join(subset_path, 'Label.txt')
            if not os.path.isdir(subset_path) or not os.path.isfile(label_file):
                continue

            # Read the label file and remember the label of every ID.
            with open(label_file, 'r') as f:
                for line in f:
                    parts = line.split()
                    if not parts:
                        # Blank or whitespace-only lines carry no ID.
                        continue
                    image_id, tags = parts[0], parts[1:]

                    # Prefix with the subset name: the same ID may appear in
                    # several subsets.
                    unique_id = f"{subset}_{image_id}"
                    self.labels[unique_id] = 1 if 'personalMale' in tags else 0

            # Collect every image of this subset that has a known label.
            for img_file in sorted(os.listdir(subset_path)):
                ext = os.path.splitext(img_file)[-1].lower()
                if ext not in self.VALID_EXTENSIONS:
                    continue

                # The ID is the part of the file name before the first underscore.
                image_id = img_file.split('_')[0]
                unique_id = f"{subset}_{image_id}"

                if unique_id in self.labels:
                    img_path = os.path.join(subset_path, img_file)
                    self.image_paths.append((img_path, self.labels[unique_id]))

    def __len__(self):
        """Return the number of labelled images that were found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``self.transform`` when
        one was given; the label is a ``torch.long`` scalar tensor.
        """
        img_path, label = self.image_paths[idx]
        # ``convert`` returns a fully loaded copy, so the underlying file can
        # be closed as soon as the ``with`` block ends.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

In [2]:
# Preprocessing applied to every image: resize to 224x224, then convert the
# PIL image to a tensor.
Transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Index every labelled image found under the local PETA folder.
Full_Dataset = PETA_Dataset(transform=Transform, root_dir='./PETA dataset')

In [3]:
# --- Model ------------------------------------------------------------------
# Hugging Face checkpoint used for the ViT backbone.
Pre_Trained_Model = 'google/vit-base-patch16-224-in21k'
# The feature extractor is loaded from the same checkpoint name.
ViT_Feature_Extractor = Pre_Trained_Model
# Size of the classification head (labels are 0 or 1).
Num_Lables = 2

# --- Optimisation -----------------------------------------------------------
LEARNING_RATE = 2e-5
EPOCHS = 10

# --- Data loading -----------------------------------------------------------
BATCH_SIZE = 64
SHUFFLE = True
Num_Workers = 0

In [4]:
import torch.utils.data as data
from torch.autograd import Variable
import numpy as np

# Hold out 20% of the samples for evaluation.
train_size = int(0.8 * len(Full_Dataset))
test_size = len(Full_Dataset) - train_size

# Seed the split explicitly: without a fixed generator every kernel restart
# produces a different train/test partition, so reported accuracies are not
# comparable between runs.
SPLIT_SEED = 42
train_ds, test_ds = torch.utils.data.random_split(
    Full_Dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print("Number of train samples: ", len(train_ds))
print("Number of test samples: ", len(test_ds))

Number of train samples:  15200
Number of test samples:  3800


In [5]:
# Both loaders use the same settings. Incomplete final batches are dropped, so
# every batch has exactly BATCH_SIZE samples, and the test loader is shuffled
# as well.
train_loader = data.DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE,
    num_workers=Num_Workers,
    drop_last=True,
)
test_loader = data.DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE,
    num_workers=Num_Workers,
    drop_last=True,
)

In [6]:
from transformers import ViTModel
from transformers.modeling_outputs import SequenceClassifierOutput
import torch.nn as nn
import torch.nn.functional as F
import tqdm as notebook_tqdm

class ViTForImageClassification(nn.Module):
    """Pre-trained ViT backbone followed by dropout and a linear classifier.

    Args:
        num_labels: Number of output classes of the linear head.
    """

    def __init__(self, num_labels=Num_Lables):
        super().__init__()
        self.vit = ViTModel.from_pretrained(Pre_Trained_Model)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        """Compute class logits for a batch of images.

        Args:
            pixel_values: Batch of preprocessed images, forwarded to the ViT
                backbone unchanged.
            labels: Accepted so existing call sites keep working, but not
                used: this module does not compute a loss, the caller is
                expected to apply its own criterion to the logits.

        Returns:
            A ``(logits, None)`` tuple. The second element is the loss slot
            and is always ``None``.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the hidden state at the first token position only.
        first_token_state = outputs.last_hidden_state[:, 0]
        logits = self.classifier(self.dropout(first_token_state))
        return logits, None

In [7]:
from transformers import ViTFeatureExtractor
import torch.nn as nn
import torch
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
print(device)
print(cuda_available)

cuda
True


In [8]:
# Classifier: pre-trained ViT backbone plus a linear head with Num_Lables outputs.
model = ViTForImageClassification(Num_Lables)

# Feature extractor loaded from the checkpoint name; rescaling is switched off.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    ViT_Feature_Extractor,
    do_rescale=False,
)

# Adam over all model parameters, trained against a cross-entropy criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_func = nn.CrossEntropyLoss()

# Wrap the model in DataParallel and move it to the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.DataParallel(model)
if device.type == 'cuda':
    model.cuda()



In [9]:
from tqdm.notebook import tqdm
# Train the model
def _extract_pixel_values(images):
    """Run a batch of image tensors through ``feature_extractor``.

    The batched CPU tensor produced by a DataLoader is split along its first
    axis into one NumPy array per image, passed to the feature extractor, and
    the resulting ``pixel_values`` are stacked back into a single tensor.
    Splitting by the actual batch length (instead of ``BATCH_SIZE``) keeps
    this correct for batches of any size.
    """
    per_image = list(images.numpy())
    return torch.tensor(np.stack(feature_extractor(per_image)['pixel_values'], axis=0))


acc = []  # test-batch accuracy at every logging step
lss = []  # training loss (Python float) at every logging step
ep = []   # batch index within the epoch at every logging step

for epoch in range(EPOCHS):
    # Wrapping the loader itself (not ``enumerate``) lets tqdm see its length,
    # so the bar shows real progress instead of "0it".
    progress = tqdm(train_loader, total=len(train_loader))
    for step, (x, y) in enumerate(progress):
        progress.set_description(f'Epoch {epoch}: {step}/{len(train_loader)}')

        # Preprocess on the CPU, then send inputs and targets to the device.
        b_x = _extract_pixel_values(x).to(device)
        b_y = y.to(device)

        # Forward pass; the model returns (logits, None), so the loss is
        # computed here with the shared criterion.
        output, _ = model(b_x, None)
        loss_train = loss_func(output, b_y)

        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()

        if step % 50 == 0:
            # Draw one batch from the (shuffled) test loader for a quick check.
            test_x, test_y = next(iter(test_loader))
            test_x = _extract_pixel_values(test_x).to(device)
            test_y = test_y.to(device)

            # Evaluate with dropout disabled and without building an autograd
            # graph, then switch back to training mode.
            model.eval()
            with torch.no_grad():
                test_output, _ = model(test_x, test_y)
            model.train()

            predictions = test_output.argmax(1)
            accuracy = (predictions == test_y).sum().item() / test_y.size(0)

            # Store a plain float: keeping the loss tensor would keep its
            # whole computation graph alive.
            train_loss = loss_train.item()
            print('Epoch: ', epoch, '| train loss: %.4f' % train_loss, '| test accuracy: %.2f' % accuracy)
            acc.append(accuracy)
            ep.append(step)
            lss.append(train_loss)

0it [00:00, ?it/s]

Epoch:  0 | train loss: 0.7130 | test accuracy: 0.62
Epoch:  0 | train loss: 0.4439 | test accuracy: 0.84
Epoch:  0 | train loss: 0.3076 | test accuracy: 0.88
Epoch:  0 | train loss: 0.2118 | test accuracy: 0.94
Epoch:  0 | train loss: 0.1811 | test accuracy: 0.94


0it [00:00, ?it/s]

Epoch:  1 | train loss: 0.1985 | test accuracy: 0.92
Epoch:  1 | train loss: 0.0989 | test accuracy: 0.89
Epoch:  1 | train loss: 0.1512 | test accuracy: 0.91
Epoch:  1 | train loss: 0.1815 | test accuracy: 1.00
Epoch:  1 | train loss: 0.1584 | test accuracy: 0.95


0it [00:00, ?it/s]

Epoch:  2 | train loss: 0.1484 | test accuracy: 0.94
Epoch:  2 | train loss: 0.0771 | test accuracy: 0.92
Epoch:  2 | train loss: 0.0491 | test accuracy: 0.94
Epoch:  2 | train loss: 0.0515 | test accuracy: 0.97
Epoch:  2 | train loss: 0.0207 | test accuracy: 0.91


0it [00:00, ?it/s]

Epoch:  3 | train loss: 0.0150 | test accuracy: 0.89
Epoch:  3 | train loss: 0.0184 | test accuracy: 0.94
Epoch:  3 | train loss: 0.0170 | test accuracy: 0.91
Epoch:  3 | train loss: 0.0215 | test accuracy: 0.89
Epoch:  3 | train loss: 0.0581 | test accuracy: 0.94


0it [00:00, ?it/s]

Epoch:  4 | train loss: 0.0113 | test accuracy: 0.98
Epoch:  4 | train loss: 0.0219 | test accuracy: 0.95
Epoch:  4 | train loss: 0.0706 | test accuracy: 0.95
Epoch:  4 | train loss: 0.0151 | test accuracy: 0.98
Epoch:  4 | train loss: 0.0295 | test accuracy: 0.98


0it [00:00, ?it/s]

Epoch:  5 | train loss: 0.0199 | test accuracy: 0.97
Epoch:  5 | train loss: 0.0069 | test accuracy: 0.95
Epoch:  5 | train loss: 0.0051 | test accuracy: 0.91
Epoch:  5 | train loss: 0.0187 | test accuracy: 0.91
Epoch:  5 | train loss: 0.0081 | test accuracy: 0.94


0it [00:00, ?it/s]

Epoch:  6 | train loss: 0.0047 | test accuracy: 0.98
Epoch:  6 | train loss: 0.0043 | test accuracy: 0.94
Epoch:  6 | train loss: 0.0179 | test accuracy: 0.89
Epoch:  6 | train loss: 0.0788 | test accuracy: 0.94
Epoch:  6 | train loss: 0.0061 | test accuracy: 0.92


0it [00:00, ?it/s]

Epoch:  7 | train loss: 0.0033 | test accuracy: 0.95
Epoch:  7 | train loss: 0.0045 | test accuracy: 0.97
Epoch:  7 | train loss: 0.0362 | test accuracy: 0.97
Epoch:  7 | train loss: 0.0031 | test accuracy: 0.95
Epoch:  7 | train loss: 0.0043 | test accuracy: 0.86


0it [00:00, ?it/s]

Epoch:  8 | train loss: 0.0939 | test accuracy: 0.94
Epoch:  8 | train loss: 0.0032 | test accuracy: 0.97
Epoch:  8 | train loss: 0.0040 | test accuracy: 0.98
Epoch:  8 | train loss: 0.0427 | test accuracy: 0.95
Epoch:  8 | train loss: 0.0053 | test accuracy: 0.98


0it [00:00, ?it/s]

Epoch:  9 | train loss: 0.0027 | test accuracy: 0.91
Epoch:  9 | train loss: 0.0066 | test accuracy: 0.95
Epoch:  9 | train loss: 0.0023 | test accuracy: 0.94
Epoch:  9 | train loss: 0.0022 | test accuracy: 0.97
Epoch:  9 | train loss: 0.0057 | test accuracy: 0.94
