In [1]:
import os
from torchvision import datasets, transforms

# Preprocessing pipeline applied to every sample: resize so the shorter side
# is 256 px, take the central 224x224 crop, then convert the image to a tensor.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# Synthetic stand-in dataset: 1000 generated 3x224x224 images from FakeData.
# To work on real images, swap this for a dataset that reads a local image directory.
dataset = datasets.FakeData(
    size=1000,
    image_size=(3, 224, 224),
    transform=transform,
)


In [2]:
import dask.bag as db
from PIL import Image
import numpy as np

# Simulated per-image processing step (a resize), used to demonstrate Dask.
def process_image(image_tensor, size=(128, 128)):
    """Resize a single image tensor and return it as a NumPy array.

    Args:
        image_tensor: Image tensor accepted by ``transforms.ToPILImage``
            (in this notebook, the tensors produced by the dataset transform).
        size: Target ``(width, height)`` handed to ``PIL.Image.resize``.
            Defaults to ``(128, 128)``, the size that used to be hard-coded.

    Returns:
        ``np.ndarray`` built from the resized PIL image.
    """
    pil_image = transforms.ToPILImage()(image_tensor)
    resized = pil_image.resize(size)
    return np.array(resized)

# Collect the image tensor of every sample (the label at index 1 is dropped),
# then let a Dask bag apply process_image to each image.
images = list(dataset[idx][0] for idx in range(len(dataset)))
source_bag = db.from_sequence(images)
bag = source_bag.map(process_image)
processed_images = bag.compute()


In [3]:
import torch
import torchvision.models as models

# Load ResNet-50 with ImageNet weights through the `weights=` API;
# `pretrained=True` is deprecated in torchvision. IMAGENET1K_V1 is the
# checkpoint that `pretrained=True` selected, so the weights are unchanged.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Switch to inference behaviour (e.g. batch-norm uses its running statistics).
model.eval()

# Per-channel statistics that torchvision's ImageNet-pretrained weights expect
# their inputs to be normalized with.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def extract_features(image_tensor, net=None, normalize=True):
    """Run one image through the network and return its output as a NumPy array.

    Args:
        image_tensor: Float image tensor of shape ``(3, H, W)`` without a batch
            dimension (in this notebook, the output of ``transforms.ToTensor``).
        net: Optional module to use instead of the notebook-level ``model``.
        normalize: When true (default), apply ImageNet mean/std normalization
            before the forward pass. Pass ``False`` if the input is already
            normalized.

    Returns:
        ``np.ndarray`` holding the network output with the batch axis removed.
        With the notebook's ResNet-50 this is the output of its final layer.
    """
    backbone = model if net is None else net

    if normalize:
        # Build (C, 1, 1) tensors so they broadcast over height and width;
        # the arithmetic creates a new tensor, so the caller's input is untouched.
        mean = torch.tensor(
            IMAGENET_MEAN, dtype=image_tensor.dtype, device=image_tensor.device
        ).view(-1, 1, 1)
        std = torch.tensor(
            IMAGENET_STD, dtype=image_tensor.dtype, device=image_tensor.device
        ).view(-1, 1, 1)
        image_tensor = (image_tensor - mean) / std

    with torch.no_grad():
        features = backbone(image_tensor.unsqueeze(0))  # Add batch dimension

    # Drop only the batch axis so other size-1 dimensions are preserved.
    return features.squeeze(0).cpu().numpy()

# Run the model on each image in turn (sequentially, one forward pass per
# image) and collect the resulting feature vectors in order.
features = list(map(extract_features, images))


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 86.4MB/s]


In [4]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Stack the per-image feature vectors into a single array for scikit-learn.
features_array = np.array(features)

# Step 1: Use PCA to reduce dimensions to 50 (optional, speeds up t-SNE).
# random_state is fixed because PCA may select a randomized SVD solver for
# data of this size, which would otherwise make reruns non-reproducible.
pca = PCA(n_components=50, random_state=0)
pca_features = pca.fit_transform(features_array)

# Step 2: Apply t-SNE for 2D visualization (seeded for reproducibility).
tsne = TSNE(n_components=2, perplexity=30, random_state=0)
reduced_embeddings = tsne.fit_transform(pca_features)


In [6]:
!pip install dash
!pip install plotly

Collecting dash
  Downloading dash-2.18.1-py3-none-any.whl.metadata (10 kB)
Collecting dash-html-components==2.0.0 (from dash)
  Downloading dash_html_components-2.0.0-py3-none-any.whl.metadata (3.8 kB)
Collecting dash-core-components==2.0.0 (from dash)
  Downloading dash_core_components-2.0.0-py3-none-any.whl.metadata (2.9 kB)
Collecting dash-table==5.0.0 (from dash)
  Downloading dash_table-5.0.0-py3-none-any.whl.metadata (2.4 kB)
Collecting retrying (from dash)
  Downloading retrying-1.3.4-py3-none-any.whl.metadata (6.9 kB)
Downloading dash-2.18.1-py3-none-any.whl (7.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.5/7.5 MB[0m [31m39.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dash_core_components-2.0.0-py3-none-any.whl (3.8 kB)
Downloading dash_html_components-2.0.0-py3-none-any.whl (4.1 kB)
Downloading dash_table-5.0.0-py3-none-any.whl (3.9 kB)
Downloading retrying-1.3.4-py3-none-any.whl (11 kB)
Installing collected packages: dash-table, dash-html-comp

In [7]:
from dash import Dash, html, dcc
import plotly.express as px

app = Dash(__name__)

# Scatter plot of the 2-D embeddings, one point per image.
fig = px.scatter(
    x=reduced_embeddings[:, 0],
    y=reduced_embeddings[:, 1],
    labels={"x": "t-SNE dimension 1", "y": "t-SNE dimension 2"},
    title="2D Visualization of Image Embeddings",
)

# Page layout: a heading followed by the interactive scatter plot.
app.layout = html.Div([
    html.H1("Image Data Embeddings Visualization"),
    dcc.Graph(id='scatter-plot', figure=fig)
])

# Start the Dash server. `app.run` replaces the deprecated `app.run_server`,
# which newer Dash releases no longer provide.
if __name__ == '__main__':
    app.run(debug=True)


<IPython.core.display.Javascript object>