In [None]:
import torch
import random
import numpy as np
from rarity_score import *
from torchvision import datasets, transforms, models
from tqdm import tqdm
from torch.utils.data import DataLoader

# Seed every random number generator used in this file (torch CPU, torch CUDA,
# NumPy and the stdlib `random` module) with the same value so that runs are
# repeatable.
seed = 2302
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# Image preprocessing pipeline: resize to 299x299, convert the image to a
# tensor, then normalize each of the three channels with the per-channel
# mean/std listed below.
transform_inception = transforms.Compose(
    [
        transforms.Resize(size=(299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# transform_resnet = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])


# Load datasets. With download=True the Oxford-IIIT Pet data is fetched into
# ./oxford_pet_data/ if needed; every image goes through transform_inception.
# wider_dataset = datasets.WIDERFace("./widerface_data", split ="train", download=True, transform = transform_inception)
pet_dataset = datasets.OxfordIIITPet(
    root="./oxford_pet_data/",
    split="trainval",
    download=True,
    transform=transform_inception,
)



# Load the pretrained Inception v3 model and move it to the selected device.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` -- confirm against the installed version before changing it.
# NOTE(review): the model is used as-is, including its final classification
# layer, so its forward output is presumably class logits rather than pooled
# features -- confirm this is what the downstream scoring expects.
inception = models.inception_v3(pretrained=True)
inception = inception.to(device)
# resnet = models.resnet50(pretrained=True).to(device)

# Batches of 64 images, drawn in shuffled order.
pet_loader = DataLoader(dataset=pet_dataset, batch_size=64, shuffle=True)
# wider_loader = DataLoader(wider_dataset, batch_size=64)


def _concat_or_empty(chunks, template=None):
    """Concatenate per-batch tensors along dim 0, tolerating an empty list.

    ``torch.cat`` raises on an empty list, so when ``chunks`` is empty an empty
    tensor is returned instead. If ``template`` is given, the empty tensor
    reuses its dtype and trailing dimensions so it stays shape-compatible with
    the non-empty side of the split.
    """
    if chunks:
        return torch.cat(chunks, dim=0)
    if template is not None:
        return template.new_empty((0, *template.shape[1:]))
    return torch.empty(0)


def extract_features_and_split(model, data_loader, fake_ratio=0.3):
    """Run ``model`` over ``data_loader`` and split its outputs into two sets.

    The split is made per batch, not per sample: the first
    ``int(len(data_loader) * fake_ratio)`` batches (in the order the loader
    yields them) are collected as "fake", and all remaining batches as "real".
    Because the batch count is truncated, the actual fraction of fake samples
    is only approximately ``fake_ratio``.

    Args:
        model: Callable module; it is switched to evaluation mode via
            ``model.eval()`` and left in that mode afterwards.
        data_loader: Sized iterable of batches whose first element
            (``batch[0]``) is the input tensor for the model.
        fake_ratio: Fraction of batches assigned to the fake set.

    Returns:
        A ``(real_features, fake_features)`` tuple of CPU tensors, each the
        concatenation along dim 0 of the corresponding model outputs. If one
        side receives no batches (e.g. very few batches, ``fake_ratio <= 0``
        or ``fake_ratio >= 1``), that side is an empty tensor rather than an
        error; if the loader yields nothing, both are empty.

    Note:
        Inputs are moved to the module-level ``device`` before the forward
        pass, so the model is expected to live on that device.
    """
    model.eval()  # Set the model to evaluation mode
    total_batches = len(data_loader)
    fake_batches = int(total_batches * fake_ratio)

    real_features_list = []
    fake_features_list = []

    # Inference only: no gradients are needed for feature extraction.
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            images = batch[0].to(device)
            # Move outputs to the CPU right away so they do not accumulate
            # on the compute device.
            output = model(images).cpu()

            bucket = fake_features_list if i < fake_batches else real_features_list
            bucket.append(output)

    # Any collected batch can serve as the shape/dtype template for an empty side.
    collected = fake_features_list or real_features_list
    template = collected[0] if collected else None

    # Concatenate all feature tensors to form single tensors for real and fake
    real_features = _concat_or_empty(real_features_list, template)
    fake_features = _concat_or_empty(fake_features_list, template)

    return real_features, fake_features

# Example usage: run the Inception model over the pet loader and split the
# outputs into "real" and "fake" sets using the function's default fake_ratio.
real_pet_features, fake_pet_features = extract_features_and_split(
    model=inception,
    data_loader=pet_loader,
)

# wider_features = extract_features(inception, wider_loader)

