In [None]:
import torch
import os
from glob import glob
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Default preprocessing pipeline: resize every image to 224x224, then
# convert the PIL image into a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
# Custom dataset class definition.
class ImageDataset(Dataset):
    """Dataset yielding RGB images together with the path they came from.

    Image paths are collected with the glob pattern
    ``<data_root>/batch*/*/*/*.jpg``, sorted, and truncated to the first
    ``len // 2 + 1`` entries.
    """

    def __init__(self, data_root, transform=transform):
        """Collect and sort the image paths under ``data_root``.

        Args:
            data_root: Directory searched for ``batch*/*/*/*.jpg`` files.
            transform: Callable applied to each loaded PIL image. A falsy
                value leaves the image untransformed.
        """
        self.data_root = data_root
        self.transform = transform

        pattern = os.path.join(data_root, 'batch*', '*', '*', '*.jpg')
        ordered_paths = sorted(glob(pattern))

        # Keep only the first half of the sorted paths (plus one entry).
        cutoff = len(ordered_paths) // 2 + 1
        self.image_paths = ordered_paths[:cutoff]

    def __len__(self):
        """Return the number of image paths kept by this dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx``.

        Returns:
            A dict with the (optionally transformed) RGB image under
            ``'image'`` and its source path under ``'image_path'``.
        """
        image_path = self.image_paths[idx]
        rgb_image = Image.open(image_path).convert('RGB')
        image = self.transform(rgb_image) if self.transform else rgb_image
        return {'image': image, 'image_path': image_path}

In [None]:
from transformers import ViTImageProcessor, ViTModel
from torch.utils.data import DataLoader
import numpy as np
import os
import torch
from tqdm import tqdm

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pretrained ViT model and its matching image processor.
image_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(device)
model.eval()  # switch the model to evaluation mode for inference

# Data loader; shuffle=False keeps the batches in the dataset's sorted order.
dataset = ImageDataset(data_root='/home/minseongjae/Downloads/AffWild2')
dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)

# Iterate over the data, extracting one feature vector per image and saving
# it as a .npy file.
with torch.no_grad():
    for data in tqdm(dataloader):
        # Feature extraction. Resizing and rescaling are disabled here because
        # the dataset's default transform has already resized the images to
        # 224x224 and converted them to tensors.
        inputs = image_processor(images=data['image'], return_tensors="pt", do_rescale=False, do_resize=False).to(device)
        outputs = model(**inputs)
        features = outputs.pooler_output
        # Copy the whole batch to host memory once, rather than issuing one
        # device-to-host transfer per image inside the loop below.
        features_np = features.to('cpu').numpy()
        for i, path in enumerate(data['image_path']):
            # Mirror the source layout under 'Features' and swap only the
            # file extension, so a '.jpg' substring elsewhere in the path
            # (e.g. in a directory name) is left untouched.
            stem, _ = os.path.splitext(path.replace('cropped_aligned', 'Features'))
            path = stem + '.npy'
            os.makedirs(os.path.dirname(path), exist_ok=True)
            np.save(path, features_np[i])