In [2]:
import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces
from PIL import Image
from stable_baselines3 import DQN
from tqdm.notebook import tqdm
from transformers import AutoImageProcessor, AutoModel

In [3]:
# Pick the best available torch backend instead of assuming Apple-silicon
# "mps", so the notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Pre-processing pipeline and backbone for the pretrained DINOv2-small checkpoint.
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-small")
model = AutoModel.from_pretrained("facebook/dinov2-small")
# The model is only used as a frozen feature extractor, so keep it in eval mode.
model = model.to(DEVICE).eval()

In [4]:
# Reference picture that the generated pixel art is compared against.
target_img_path = "./grant-headshot.png"
target_img = Image.open(target_img_path)

In [5]:
def get_img_vects(imgs):
    """Embed one or more images with the module-level ``model``.

    Args:
        imgs: A single image or a sequence of images in any form accepted by
            the module-level ``processor`` (the notebook passes ``PIL.Image``
            objects).

    Returns:
        numpy.ndarray: ``last_hidden_state`` with the first token of every
        sequence removed, indexed as ``(image, token, feature)``.
        NOTE(review): for DINOv2 the first token is presumably the CLS token,
        leaving only patch tokens -- confirm against the model card.
    """
    inputs = processor(images=imgs, return_tensors="pt")
    # Follow the model to whichever device it lives on rather than hard-coding
    # a backend name here.
    device = model.device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    # Pure inference: disabling autograd avoids building (and then throwing
    # away) a graph on every call, which matters because this runs once per
    # environment step.
    with torch.no_grad():
        outputs = model(**inputs)
    token_vects = outputs.last_hidden_state[:, 1:, :]
    return token_vects.cpu().numpy()

In [6]:
# get_img_vects returns one entry per input image; a single image was passed,
# so keep only the first (and only) entry.
target_img_vects = get_img_vects(target_img)
target_vect = target_img_vects[0]

  return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors)


In [15]:
class PixelArtDQNEnv(gym.Env):
    """Gymnasium environment in which an agent paints a binary pixel grid.

    The state is a flat ``uint8`` vector of ``img_size * img_size`` cells, each
    0 or 1. An action ``a`` sets cell ``a // 2`` to the value ``a % 2``. After
    every action the grid is rendered as a black/white RGB image, embedded with
    ``get_img_vects`` and compared with the embedding of the target image.

    The reward is the *negative* mean squared error between the two
    embeddings, so maximising reward drives the grid towards the target.

    Args:
        target_image: Image the agent should approximate; passed unchanged to
            ``get_img_vects``.
        img_size: Side length of the square pixel grid.
        max_steps: Optional episode length. When given, ``step`` reports
            ``truncated=True`` once that many steps have been taken since the
            last reset. ``None`` (default) means episodes never end.

    Raises:
        ValueError: If ``img_size`` or ``max_steps`` is not positive.
    """

    def __init__(self, target_image, img_size=32, max_steps=None):
        super().__init__()
        if img_size <= 0:
            raise ValueError(f"img_size must be positive, got {img_size}")
        if max_steps is not None and max_steps <= 0:
            raise ValueError(f"max_steps must be positive or None, got {max_steps}")

        self.target_image = target_image
        # Embed the target once; it is reused for every reward computation.
        self.target_vect = get_img_vects(target_image)[0]
        self.img_size = img_size
        self.max_steps = max_steps
        self._n_pixels = img_size * img_size
        self._elapsed_steps = 0

        # Two actions per pixel: "set to 0" and "set to 1".
        self.action_space = spaces.Discrete(self._n_pixels * 2)
        # Cells only ever hold 0 or 1, so the declared bounds say exactly that.
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(self._n_pixels,), dtype=np.uint8
        )
        self.state = self._random_state()

    def _random_state(self):
        """Draw a random 0/1 grid from the environment's own RNG."""
        return self.np_random.integers(0, 2, size=self._n_pixels, dtype=np.uint8)

    def _state_to_image(self):
        """Render the current grid as a black/white RGB ``PIL.Image``."""
        grid = self.state.reshape(self.img_size, self.img_size)
        return Image.fromarray(grid * np.uint8(255)).convert("RGB")

    def step(self, action):
        """Apply one pixel assignment and score the resulting picture.

        Args:
            action: Integer in ``[0, 2 * img_size**2)``.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``observation`` is a copy of the grid and ``reward`` is the
            negative embedding error as a Python float.

        Raises:
            ValueError: If ``action`` is outside the action space.
        """
        action = int(action)
        if not 0 <= action < self._n_pixels * 2:
            # Without this check a negative action would silently wrap around.
            raise ValueError(f"action {action} is outside the action space")

        pixel_index, pixel_value = divmod(action, 2)
        self.state[pixel_index] = pixel_value

        # calculate_err is an error (lower is better); RL maximises reward,
        # so the sign has to be flipped.
        reward = -self.calculate_err(self._state_to_image())

        self._elapsed_steps += 1
        terminated = False  # there is no natural terminal state
        truncated = self.max_steps is not None and self._elapsed_steps >= self.max_steps

        # Return a copy so callers never alias the array mutated by later steps.
        return self.state.copy(), reward, terminated, truncated, {}

    def reset(self, *, seed=None, options=None, **kwargs):
        """Start a new episode from a random grid.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` so the random
                initial grid is reproducible.
            options: Accepted for Gymnasium API compatibility; unused.
            **kwargs: Ignored; kept so existing callers passing extra keyword
                arguments keep working.

        Returns:
            ``(observation, info)`` with a copy of the fresh grid and an empty
            info dict.
        """
        super().reset(seed=seed)
        self.state = self._random_state()
        self._elapsed_steps = 0
        return self.state.copy(), {}

    def calculate_err(self, img):
        """Return the mean squared error between ``img`` and the target.

        Both images are compared in embedding space, using the vectors
        produced by ``get_img_vects``.

        Args:
            img: Image to score, in a form accepted by ``get_img_vects``.

        Returns:
            float: Mean squared difference; 0 means identical embeddings.
        """
        img_vect = get_img_vects(img)[0]
        squared_diff = (img_vect - self.target_vect) ** 2
        return float(squared_diff.mean())

In [16]:
# Wrap the loaded target picture in the pixel-art environment.
env = PixelArtDQNEnv(target_image=target_img, img_size=32)

# DQN hyper-parameters, spelled out one per line so they are easy to tune.
dqn_kwargs = dict(
    verbose=3,
    learning_rate=1e-4,
    buffer_size=10000,
    learning_starts=1000,
    batch_size=32,
    tau=1.0,
    gamma=0.99,
    train_freq=4,
    gradient_steps=1,
    optimize_memory_usage=False,
    target_update_interval=500,
    exploration_fraction=0.1,
    exploration_initial_eps=1.0,
    exploration_final_eps=0.05,
    max_grad_norm=10,
)
rl_model = DQN("MlpPolicy", env, **dqn_kwargs)
rl_model.learn(total_timesteps=20000)


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


<stable_baselines3.dqn.dqn.DQN at 0x2f74eb4d0>