# Art Style Explorer Demo

This notebook demonstrates the workflow of the Art Style Explorer project, which analyzes artwork images to find stylistically similar artists and artworks.

In [1]:
!pip install -r ../requirements.txt



In [2]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
from PIL import Image

# Add parent directory to path
sys.path.append('..')

## 1. Setup

First, let's import the necessary modules from our project.

In [3]:
from src.preprocessing.edge_detection import process_artwork, detect_contours, analyze_contours
from src.preprocessing.normalization import standardize_artwork
from src.features.line_features import extract_line_features, extract_hough_lines, visualize_lines
from src.features.composition import extract_composition_features
from src.utils.visualization import (
    visualize_edges,
    visualize_line_features,
    visualize_composition_features,
    visualize_similar_artworks
)

## 2. Load and Process an Artwork

Now, let's load an artwork image and process it.

In [4]:
# Artwork to analyze -- point this at your own image to try a different piece
image_path = '../sample_data/starry_night.jpg'

# Settings handed to process_artwork: target size and edge-detection method
edge_settings = {'target_size': (512, 512), 'edge_method': 'canny'}

# process_artwork returns two results, kept here as the preprocessed image
# and its edge map (both are reused by the later cells)
preprocessed, edges = process_artwork(image_path, **edge_settings)

# Render the edge-detection visualization for this artwork
fig = visualize_edges(preprocessed, edges)
plt.show()

[ WARN:0@122.287] global loadsave.cpp:248 findDecoder imread_('../sample_data/starry_night.jpg'): can't open/read file: check file path/integrity


FileNotFoundError: Image not found at ../sample_data/starry_night.jpg

## 3. Extract Line Features

Let's extract and visualize line features from the artwork.

In [None]:
# Contours and Hough lines are both derived from the edge image
contours = detect_contours(edges)
lines = extract_hough_lines(edges)

# Draw the detected lines using the project's visualization helper
line_fig = visualize_line_features(preprocessed, lines)
plt.show()

# Compute the line-feature dictionary from the edges and contours
line_features = extract_line_features(edges, contours)

# Summary table: (label, feature key, format spec); an empty spec prints the
# value as-is, '.2f' rounds to two decimals
line_report = (
    ("Number of lines", 'line_count', ''),
    ("Mean line length", 'mean_length', '.2f'),
    ("Horizontal ratio", 'horizontal_ratio', '.2f'),
    ("Vertical ratio", 'vertical_ratio', '.2f'),
    ("Diagonal ratio", 'diagonal_ratio', '.2f'),
    ("Mean curvature", 'mean_curvature', '.2f'),
    ("Number of intersections", 'intersection_count', ''),
    ("Number of contours", 'contour_count', ''),
    ("Average complexity", 'avg_complexity', '.2f'),
)
for label, key, spec in line_report:
    print(f"{label}: {line_features[key]:{spec}}")

## 4. Extract Composition Features

Now, let's extract and visualize composition features.

In [None]:
# Compute the composition-feature dictionary from the preprocessed image
composition_features = extract_composition_features(preprocessed)

# Draw the composition visualization using the project's helper
comp_fig = visualize_composition_features(preprocessed, composition_features)
plt.show()

# Summary table: (label, feature key); every value is printed with two decimals
composition_report = (
    ("Rule of thirds adherence", 'thirds_adherence'),
    ("Horizontal symmetry", 'horizontal_symmetry'),
    ("Vertical symmetry", 'vertical_symmetry'),
    ("Overall symmetry", 'overall_symmetry'),
    ("Golden ratio adherence", 'golden_ratio_adherence'),
    ("Horizontal balance", 'horizontal_balance'),
    ("Vertical balance", 'vertical_balance'),
    ("Radial balance", 'radial_balance'),
)
for label, key in composition_report:
    print(f"{label}: {composition_features[key]:.2f}")

## 5. Create a Feature Vector

Let's combine the line and composition features into a single feature vector.

In [None]:
# Helper that flattens both feature dictionaries into one vector (same as in main.py)
def get_features_vector(line_features, composition_features):
    """Concatenate line and composition features into a single float32 array.

    The output order is fixed: the fifteen scalar line features, then every
    entry of ``line_features['orientation_histogram']``, then the fifteen
    scalar composition features.

    Args:
        line_features: Mapping holding the scalar line-feature keys listed
            below plus an iterable under ``'orientation_histogram'``.
        composition_features: Mapping holding the scalar composition-feature
            keys listed below.

    Returns:
        A NumPy array of dtype ``float32`` containing the values in the
        order described above.

    Raises:
        KeyError: If any expected key is missing from either mapping.
    """
    # Scalar line features, in output order
    line_keys = (
        'line_count',
        'mean_length',
        'std_length',
        'max_length',
        'min_length',
        'horizontal_ratio',
        'vertical_ratio',
        'diagonal_ratio',
        'mean_curvature',
        'std_curvature',
        'max_curvature',
        'min_curvature',
        'intersection_count',
        'contour_count',
        'avg_complexity',
    )
    # Scalar composition features, in output order
    comp_keys = (
        'horizontal_line_energy',
        'vertical_line_energy',
        'intersection_energy',
        'thirds_adherence',
        'horizontal_symmetry',
        'vertical_symmetry',
        'diagonal_symmetry',
        'overall_symmetry',
        'golden_horizontal_energy',
        'golden_vertical_energy',
        'golden_spiral_energy',
        'golden_ratio_adherence',
        'horizontal_balance',
        'vertical_balance',
        'radial_balance',
    )

    line_scalars = [line_features[key] for key in line_keys]
    # The histogram contributes one slot per entry, right after the line scalars
    histogram = line_features['orientation_histogram']
    comp_scalars = [composition_features[key] for key in comp_keys]

    return np.array([*line_scalars, *histogram, *comp_scalars], dtype=np.float32)

# Build the combined vector from the two feature dictionaries computed earlier
features_vector = get_features_vector(line_features, composition_features)

# Report the vector's dimensions
print(f"Feature vector shape: {features_vector.shape}")

# Bar chart with one bar per feature index
feature_positions = range(len(features_vector))
plt.figure(figsize=(15, 5))
plt.bar(feature_positions, features_vector)
for set_text, text in (
    (plt.title, 'Feature Vector'),
    (plt.xlabel, 'Feature Index'),
    (plt.ylabel, 'Feature Value'),
):
    set_text(text)
plt.show()

## 6. Load the Neural Network Model

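Now, let's set up the neural network model for art style analysis. In this demo the model is only instantiated, without loading a checkpoint; the commented-out `load_model` call shows how to load a pre-trained model instead.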

In [None]:
# Model class, plus the checkpoint-loading helper referenced in the note below
from src.model.network import ArtStyleNetwork
from src.model.training import load_model

# Run on the CPU for now (switch to a GPU if one is available)
device = 'cpu'

# Constructor arguments for the network
model_config = {
    'input_channels': 1,  # Grayscale images
    'line_feature_dim': 128,
    'comp_feature_dim': 64,
    'embedding_dim': 256,
    'num_classes': 100,  # Placeholder for number of artist classes
    'comp_input_dim': len(features_vector),  # Use actual feature vector length
}

# In a real scenario, a pre-trained checkpoint would be loaded after construction:
# model = load_model(model, 'path/to/checkpoint.pth', device=device)
# No checkpoint is loaded in this demo, so the model keeps whatever weights
# its constructor assigned.
model = ArtStyleNetwork(**model_config).to(device)

## 7. Extract Embedding

Let's extract an art style embedding from our image using the model.

In [None]:
# Convert the image and features to torch tensors.
# The two unsqueeze(0) calls prepend two size-1 axes to the image (presumably
# batch and channel -- confirm against ArtStyleNetwork); dividing by 255.0
# rescales it, assuming pixel values in the 0-255 range.
image_tensor = torch.tensor(preprocessed, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
features_tensor = torch.tensor(features_vector, dtype=torch.float32).unsqueeze(0)

# Move to device
image_tensor = image_tensor.to(device)
features_tensor = features_tensor.to(device)

# Put the model in evaluation mode before inference. torch.no_grad() only
# disables gradient tracking; without eval(), layers such as Dropout or
# BatchNorm (if the network has any) would still run in training mode on this
# single-sample batch.
# NOTE(review): assumes ArtStyleNetwork is a torch.nn.Module -- confirm.
model.eval()

# Extract embedding
with torch.no_grad():
    embedding = model.extract_features(image_tensor, features_tensor)

# Print the embedding shape
print(f"Embedding shape: {embedding.shape}")

# Plot the embedding as a bar chart (first row of the batch, one bar per dimension)
plt.figure(figsize=(15, 5))
plt.bar(range(embedding.shape[1]), embedding.cpu().numpy()[0])
plt.title('Art Style Embedding')
plt.xlabel('Dimension')
plt.ylabel('Value')
plt.show()

## 8. Find Similar Artworks (Simulation)

Since we don't have a database of artworks in this demo, let's simulate finding similar artworks by generating random embeddings.

In [None]:
# Import functions for finding similar artworks
import torch.nn.functional as F

# Simulate a database of artworks
num_database_artworks = 100
embedding_dim = embedding.shape[1]

# Dedicated, seeded generator so the simulated database is identical on every
# run and the global torch RNG state is left untouched
database_seed = 42
database_rng = torch.Generator().manual_seed(database_seed)

# Generate random embeddings for the database
# In a real application, these would be extracted from actual artworks
database_embeddings = torch.randn(
    num_database_artworks, embedding_dim, generator=database_rng
)

# L2-normalize each row, then place the database on the same device as the
# query embedding so the comparison also works when the model runs off-CPU
database_embeddings = F.normalize(database_embeddings, p=2, dim=1).to(embedding.device)

# Simulate artist names and artwork titles (20 artists cycled over the database)
artists = [f"Artist {i % 20 + 1}" for i in range(num_database_artworks)]
titles = [f"Artwork {i + 1}" for i in range(num_database_artworks)]

# Calculate cosine similarity between the query embedding and database embeddings
similarities = F.cosine_similarity(embedding, database_embeddings)

# Find the top-k most similar artworks; topk returns both the scores and
# their positions, so no second indexing pass is needed
top_k = 5
top_matches = torch.topk(similarities, k=top_k)
top_indices = top_matches.indices.cpu().numpy()
top_similarities = top_matches.values.cpu().numpy()

# Print the results (header follows top_k instead of a hard-coded count)
print(f"Top {top_k} similar artworks:")
for rank, (idx, sim) in enumerate(zip(top_indices, top_similarities), start=1):
    print(f"{rank}. {artists[idx]} - {titles[idx]} (Similarity: {sim:.4f})")

## 9. Visualize Similar Artworks (Simulation)

Let's simulate visualizing the query image and similar artworks.

In [None]:
# Load the original query image.
# cv2.imread does not raise when the file is missing or unreadable; it returns
# None, which would otherwise surface as an opaque error inside cvtColor.
query_image = cv2.imread(image_path)
if query_image is None:
    raise FileNotFoundError(f"Image not found or unreadable at {image_path}")

# OpenCV loads images in BGR channel order; convert to RGB for display
query_image = cv2.cvtColor(query_image, cv2.COLOR_BGR2RGB)

# Generate random similar images for demonstration
# In a real application, these would be actual similar artworks
similar_images = [np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8) for _ in range(top_k)]

# Artist and title labels for the top matches, in ranking order
similar_artists = [artists[idx] for idx in top_indices]
similar_titles = [titles[idx] for idx in top_indices]

# Visualize the query image and similar artworks
similar_fig = visualize_similar_artworks(
    query_image,
    similar_images,
    similar_artists,
    similar_titles,
    top_similarities
)

plt.show()

## 10. Complete Workflow

In a real application, you would:

1. Collect a large dataset of artwork images with associated metadata
2. Train the neural network model using the training script
3. Extract features from all artworks in the database
4. Use the main application to analyze new artwork images and find similar styles

The command-line interface makes it easy to use the application:

In [None]:
# Joins command parts so each printed line ends with a shell continuation
# backslash and the next flag is indented. (A backslash-newline written inside
# an ordinary string literal is swallowed by Python, which would collapse the
# command onto one line with stray runs of spaces.)
LINE_CONTINUATION = " \\\n    "

# Example training invocation, one flag per line
TRAIN_COMMAND = LINE_CONTINUATION.join([
    "python train.py",
    "--data-dir /path/to/artwork/images",
    "--metadata /path/to/metadata.csv",
    "--output-dir checkpoints",
    "--batch-size 32",
    "--epochs 50",
    "--device cuda",
    "--data-augmentation",
])

# Example inference invocation, one flag per line
INFERENCE_COMMAND = LINE_CONTINUATION.join([
    "python main.py",
    "--input /path/to/query/image.jpg",
    "--output-dir results",
    "--database /path/to/metadata.csv",
    "--data-dir /path/to/artwork/images",
    "--mappings checkpoints/artist_mappings.json",
    "--model checkpoints/model_final.pth",
    "--top-k 5",
    "--visualize",
])

print("Training command example:")
print(TRAIN_COMMAND)

print("\nInference command example:")
print(INFERENCE_COMMAND)