In [1]:
import fiftyone as fo
import torchvision
import os 
import torch
from PIL import Image
import numpy as np
import fiftyone.brain as fob
from sklearn.neighbors import NearestNeighbors


## Create Dataset

In [2]:
samples = []

data_root2 = '../Data/Snickers_Real_Image/10/'
data_root1 = '../Data/Snickers_Digital_Twin/10/'

folder1 = os.listdir(data_root1)
folder2 = os.listdir(data_root2)


for count, file in enumerate(folder1):
    if count >len(folder2):
        break
    sample = fo.Sample(filepath = os.path.join(data_root1,file))
    label = 'Synthetic'
    sample["ground_truth"] = fo.Classification(label=label)
    samples.append(sample)

for count, file in enumerate(folder2):
    sample = fo.Sample(filepath = os.path.join(data_root2,file))
    label = 'Real'
    sample["ground_truth"] = fo.Classification(label=label)
    samples.append(sample)

# Create dataset
dataset = fo.Dataset("my-classification-dataset")
dataset.add_samples(samples)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

 100% |█████████████████| 281/281 [116.5ms elapsed, 0s remaining, 2.4K samples/s]  


## Image loading 

In [3]:
def pil_loader(path: str) -> Image.Image:
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
    
def get_image(path,transform):
    
    x = pil_loader(path)
    x = transform(x)
    return(x)

## Calculate image normalization

In [4]:
train_transform = torchvision.transforms.Compose([ 
    torchvision.transforms.Resize(size = (256,256)),
    torchvision.transforms.ToTensor(),
])

for count, file in enumerate(folder1):
    im_file = os.path.join(data_root1, file)
    images = get_image(im_file, train_transform)
    if count==0:
        image_stats = images.view(3,-1)
    else:
        image_stats = torch.cat((images.view(3,-1), image_stats), dim = 1)

R = image_stats[0,:]
G = image_stats[1,:]
B = image_stats[2,:]

R1 = R[R!=0].mean()
G1 = G[G!=0].mean()
B1 = B[B!=0].mean()
Rstd1 = (R[R!=0].std())
Gstd1 = (G[G!=0].std())
Bstd1 = (B[B!=0].std())

for count, file in enumerate(folder2):
    im_file = os.path.join(data_root2, file)
    images = get_image(im_file, train_transform)
    if count==0:
        image_stats = images.view(3,-1)
    else:
        image_stats = torch.cat((images.view(3,-1), image_stats), dim = 1)
        
R = image_stats[0,:]
G = image_stats[1,:]
B = image_stats[2,:]

R2 = R[R!=0].mean()
G2 = G[G!=0].mean()
B2 = B[B!=0].mean()
Rstd2 = (R[R!=0].std())
Gstd2 = (G[G!=0].std())
Bstd2 = (B[B!=0].std())


## Image Normalization

In [5]:
fake_transform = torchvision.transforms.Compose([ 
    torchvision.transforms.Resize(size = (256,256)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[R1, G1, B1],
                 std=[Rstd1, Gstd1, Bstd1]),
    ## undo fasterrcnn normalization
    torchvision.transforms.Normalize(mean = [ 0., 0., 0. ],
                                 std = [ 1/0.229, 1/0.224, 1/0.225 ]),
    torchvision.transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
                                 std = [ 1., 1., 1. ]),
])

real_transform = torchvision.transforms.Compose([ 
    torchvision.transforms.Resize(size = (256,256)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[R2, G2, B2],
             std=[Rstd2, Gstd2, Bstd2]),
    ## undo fasterrcnn normalization
    torchvision.transforms.Normalize(mean = [ 0., 0., 0. ],
                                 std = [ 1/0.229, 1/0.224, 1/0.225 ]),
    torchvision.transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
                                 std = [ 1., 1., 1. ]),

])

## Setup pre-trained model for generating embeddings

In [6]:
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True).to(device)

class SaveFeatures():
    features=None
    def __init__(self, m): 
        self.hook = m.register_forward_hook(self.hook_fn)
        self.features = None
    def hook_fn(self, module, input, output): 
        out = output.detach()#.cpu().numpy()
        self.features = out.flatten()

    def remove(self): 
        self.hook.remove()
        
hook = SaveFeatures(model.roi_heads.box_head.fc7)


def model_embedding(model, images):
    ## renset50
    _ = model(images)
    return hook.features[0:1024]

## Generate Embeddings

In [7]:
model.eval()
n_features = 1024
num_images = len(dataset.values('filepath'))

with torch.no_grad():
    
    embedding = torch.zeros(num_images, n_features)
    index_count = 0 
    for index, (f, label_dict) in enumerate(zip(dataset.values('filepath'), dataset.values('ground_truth'))):
        if 'Synthetic' in label_dict['label']:
            transform = fake_transform
        

        elif 'Real'  in label_dict['label']:
            transform = real_transform
        images = get_image(f, transform)
        images = images.unsqueeze(0).to(device)  
        
        out = model_embedding(model, images.to(device))
        embedding[index,:] = (out)

        


torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)



## umap Dimensionality reduction

In [8]:
results = fob.compute_visualization(
    dataset,
    embeddings=embedding.detach().cpu().numpy(),
    num_dims=2,
    method="umap",
    brain_key="None",
    verbose=True,
    seed=51,
)


Generating visualization...
UMAP(dens_frac=0.0, dens_lambda=0.0, random_state=51, verbose=True)
Construct fuzzy simplicial set



The TBB threading layer requires TBB version 2019.5 or later i.e., TBB_INTERFACE_VERSION >= 11005. Found TBB_INTERFACE_VERSION = 9107. The TBB threading layer is disabled.



Fri May 20 11:39:58 2022 Finding Nearest Neighbors
Fri May 20 11:40:00 2022 Finished Nearest Neighbor Search
Fri May 20 11:40:02 2022 Construct embedding
	completed  0  /  500 epochs
	completed  50  /  500 epochs
	completed  100  /  500 epochs
	completed  150  /  500 epochs
	completed  200  /  500 epochs
	completed  250  /  500 epochs
	completed  300  /  500 epochs
	completed  350  /  500 epochs
	completed  400  /  500 epochs
	completed  450  /  500 epochs
Fri May 20 11:40:03 2022 Finished embedding


## Voxel 51 visualization

In [9]:
session = fo.launch_app(view=dataset.view())
plot = results.visualize(labels="ground_truth.label")
plot.show(height=720)
session.plots.attach(plot)





FigureWidget({
    'data': [{'customdata': array(['6287e027ad35272bd64cb40e', '6287e027ad35272bd64cb40f',
    …

## Calculate data overlap

In [10]:
y = np.array([int('Real' in d['label']) for d in dataset.values('ground_truth')])
X = results.points
nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(X)
nearest_class = []

## find nearest points to Synthetic data
for data in X[y==0]:
    distances, indices = nbrs.kneighbors(data.reshape(1, -1))
    nearest_class.append(y[indices[0][1]])    

n_real = np.sum([int(x==1) for x in nearest_class])
n_fake = np.sum([int(x==0) for x in nearest_class])


ns = (len(nearest_class))
z = 0
x = n_fake/len(nearest_class)
a = x
b = z
c = np.sqrt(ns)
p_val = (-np.sqrt(-4*a**2*b**2*c**2+4*a*b**2*c**2 + b**4) + 2*a*c**2 + b**2) / (2*(b**2+c**2))

nd = len(y)
nq = (y==1).sum() 
data_overlap = -(p_val-1) * (nd / nq)
print('data overlap = %f' %data_overlap)

data overlap = 0.854103
