In [1]:
import json

In [2]:
# Load the outfit records. The encoding is given explicitly because JSON is
# UTF-8 by specification, while open() otherwise falls back to the platform's
# locale encoding and could mis-decode non-ASCII text.
with open('train_no_dup_new_100.json', 'r', encoding='utf-8') as f:
    outfits = json.load(f)

# Report how many outfit records were loaded.
print(len(outfits))

16983


In [4]:
import torch
from torch_geometric.data import Data, Dataset
import os
import os.path as osp
from tqdm import tqdm

import torchvision.transforms as transforms
from torchvision.models import inception_v3
from PIL import Image

# Pretrained Inception v3 fetched through torch.hub; it is the image encoder
# that outfitsDataset reads from module scope.
model = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'inception_v3',
    pretrained=True,
)
# Replace the final classification layer with a pass-through so a forward pass
# returns the features that used to feed the classifier.
model.fc = torch.nn.Identity()
# Switch to evaluation mode (same effect as model.eval()).
model.train(False)


class outfitsDataset(Dataset):
    """On-disk PyTorch Geometric dataset holding one graph per outfit.

    Every selected item image of an outfit becomes a node whose feature vector
    is produced by the module-level ``model``; each ordered pair of distinct
    nodes is joined by an edge, so the graph is fully connected. Graphs are
    written to ``processed_dir`` as ``data_{idx}.pt`` and read back by ``get``.

    Args:
        root: Dataset root directory handed to the PyG base class.
        outfits: Sequence of mappings, each providing ``'set_id'`` and
            ``'items_index'``.
        transform, pre_transform, pre_filter: Forwarded to the PyG base class;
            ``pre_filter`` and ``pre_transform`` are applied in ``process``.
    """

    # Directory that contains one sub-directory of item images per outfit set_id.
    IMAGES_DIR = '..\\Outfit-Recommender-GNN\\images'
    # Directory listed by ``raw_file_names``.
    RAW_DIR = '..\\Outfit-Recommender-GNN\\outfitsData'
    # Sentinel returned by ``process_outfit`` when no image directory exists
    # (kept as a string for backward compatibility with existing callers).
    NOT_FOUND = "could not find"

    # Preprocessing pipeline, built once instead of once per image:
    # resize/crop to 299x299, convert to tensor, normalise per channel.
    _IMAGE_TRANSFORM = transforms.Compose([
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, root, outfits, transform=None, pre_transform=None, pre_filter=None):
        # Assigned before super().__init__() because the PyG base constructor
        # can trigger process(), which reads self.outfits.
        self.outfits = outfits
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """Names of the entries found in ``RAW_DIR``."""
        return os.listdir(self.RAW_DIR)

    @property
    def processed_file_names(self):
        """Sorted names of the ``data_*.pt`` graph files in ``processed_dir``.

        Returns an empty list when the directory does not exist yet, and
        ignores any other file stored there (PyG keeps its own bookkeeping
        files in the same directory), so ``len()`` counts graphs only.
        """
        if not osp.isdir(self.processed_dir):
            return []
        return sorted(
            name for name in os.listdir(self.processed_dir)
            if name.startswith('data_') and name.endswith('.pt')
        )

    def preprocess_image(self, img_path):
        """Load an image and return it as a normalised ``(1, 3, 299, 299)`` tensor."""
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted to a tensor.
        with Image.open(img_path) as img:
            rgb = img if img.mode == 'RGB' else img.convert("RGB")
            # unsqueeze(0) adds the batch dimension the network expects.
            return self._IMAGE_TRANSFORM(rgb).unsqueeze(0)

    def extract_features(self, img_path, model):
        """Run ``model`` on one image and return its output without the batch dimension."""
        img_tensor = self.preprocess_image(img_path)
        with torch.no_grad():  # inference only, no autograd bookkeeping
            features = model(img_tensor)
        return features.squeeze()

    def process_outfit(self, set_id, indexes):
        """Extract features for the selected images of one outfit.

        Args:
            set_id: Name of the outfit's sub-directory inside ``IMAGES_DIR``.
            indexes: Positions, within the directory listing, of the images to use.

        Returns:
            A list with one feature tensor per successfully processed image, or
            the string ``"could not find"`` when the outfit directory is missing.
        """
        # Test the expected path directly rather than listing every outfit
        # directory on each call.
        set_path = os.path.join(self.IMAGES_DIR, str(set_id))
        if not os.path.isdir(set_path):
            return self.NOT_FOUND

        images = []
        # NOTE(review): positions follow os.listdir order, which the OS does not
        # guarantee - confirm this matches how 'items_index' was generated.
        for i, image in enumerate(os.listdir(set_path)):
            if i not in indexes:
                continue
            img_path = os.path.join(set_path, image)
            try:
                images.append(self.extract_features(img_path, model))
            except Exception as e:
                # Best effort: report the unreadable image and keep going.
                print(f"Error loading '{img_path}': {e}")
        return images

    def process(self):
        """Build one fully connected graph per outfit and save it to ``processed_dir``.

        Outfits whose image directory is missing, or for which no image could
        be processed, are reported and skipped so that the saved files keep
        consecutive indices.
        """
        idx = 0
        for outfit in tqdm(self.outfits):
            features = self.process_outfit(outfit['set_id'], outfit['items_index'])
            if isinstance(features, str) or len(features) == 0:
                print(f"Skipping outfit {outfit['set_id']}: no item features extracted")
                continue

            # Node feature matrix with one row per item image.
            x = torch.stack(features)
            # The graph size is the number of extracted nodes, not the number
            # of keys in the outfit record.
            num_nodes = x.size(0)
            pairs = [[i, j] for i in range(num_nodes) for j in range(num_nodes) if i != j]
            if pairs:
                edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
            else:
                # A single-node outfit has no edges; keep the (2, 0) layout.
                edge_index = torch.empty((2, 0), dtype=torch.long)

            data = Data(x=x, edge_index=edge_index)

            # Honour the optional hooks accepted by the constructor.
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        """Number of graph files currently present in ``processed_dir``."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load and return the graph stored as ``data_{idx}.pt``."""
        # NOTE(review): on PyTorch versions where torch.load defaults to
        # weights_only=True, loading a pickled Data object may require
        # weights_only=False - confirm against the installed version.
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

Using cache found in C:\Users\Admin/.cache\torch\hub\pytorch_vision_v0.10.0


In [5]:
# Restrict the run to the first 1000 outfit records.
outfits = outfits[0:1000]

# Instantiate the dataset rooted at the project's outfitsData directory.
outfits_graphs = outfitsDataset(
    root='..\\Outfit-Recommender-GNN\\outfitsData',
    outfits=outfits,
)

# Explicitly (re)build every graph; process() loops over all outfits
# unconditionally, so this repeats feature extraction on each execution.
outfits_graphs.process()

# Sanity check: size of the first node's feature vector in the second graph.
print(outfits_graphs[1].x[0].shape)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [11:37<00:00,  1.43it/s]

torch.Size([2048])



