In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

class ColorAutoencoder(nn.Module):
    """Small MLP autoencoder that compresses a color vector into a code.

    Encoder: ``input_dim -> 64 -> 32 -> bottleneck_dim`` with ReLU between
    layers and a linear (unbounded) code layer. The decoder mirrors it,
    ``bottleneck_dim -> 32 -> 64 -> input_dim``, and ends in a Sigmoid, so
    reconstructions lie in (0, 1).
    """

    def __init__(self, input_dim, bottleneck_dim):
        super().__init__()
        wide, narrow = 64, 32
        # Layers are created in the same order (encoder first) so seeded
        # initialisation and state_dict keys stay stable.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, wide), nn.ReLU(),
            nn.Linear(wide, narrow), nn.ReLU(),
            nn.Linear(narrow, bottleneck_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_dim, narrow), nn.ReLU(),
            nn.Linear(narrow, wide), nn.ReLU(),
            nn.Linear(wide, input_dim), nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the pair ``(code, reconstruction)`` for input ``x``."""
        code = self.encoder(x)
        return code, self.decoder(code)

class ColorCategoryGame:
    """Population of autoencoder agents playing a speaker/hearer color game.

    Each round a randomly chosen speaker encodes a color into its bottleneck
    code, and a different, randomly chosen hearer decodes that code back into
    a color. The hearer is always trained on the reconstruction error; the
    speaker is also updated when that error exceeds a threshold.
    """

    def __init__(self, num_agents, input_dim, bottleneck_dim):
        self.num_agents = num_agents
        self.agents = [ColorAutoencoder(input_dim, bottleneck_dim) for _ in range(num_agents)]
        # One optimizer per agent so speaker and hearer can be stepped independently.
        self.optimizers = [optim.Adam(agent.parameters()) for agent in self.agents]
        self.criterion = nn.MSELoss()

    def generate_color_stimuli(self, batch_size):
        """Return ``batch_size`` colors as a ``(batch_size, 3)`` tensor uniform in [0, 1)."""
        # Generate color stimuli according to JND function
        # This is a placeholder and should be implemented based on the actual JND function
        return torch.rand(batch_size, 3)  # RGB colors

    def play_game(self, num_rounds, speaker_update_threshold=0.1, log_every=1000):
        """Run ``num_rounds`` rounds of the speaker/hearer game, training in place.

        Args:
            num_rounds: number of rounds to play.
            speaker_update_threshold: the speaker is also stepped when the
                round's MSE loss is strictly greater than this value.
            log_every: print the loss every ``log_every`` rounds; a falsy
                value disables logging.

        Raises:
            ValueError: if the population has fewer than two agents, since the
                speaker and hearer must be different agents.
        """
        if self.num_agents < 2:
            # With a single agent no distinct hearer exists; resampling would loop forever.
            raise ValueError(f"play_game needs at least 2 agents, got {self.num_agents}")

        for round_idx in range(num_rounds):
            # Uniform over ordered pairs of distinct agents (speaker, hearer).
            speaker_idx, hearer_idx = (
                int(i) for i in np.random.choice(self.num_agents, size=2, replace=False)
            )
            speaker = self.agents[speaker_idx]
            hearer = self.agents[hearer_idx]
            speaker_opt = self.optimizers[speaker_idx]
            hearer_opt = self.optimizers[hearer_idx]

            stimuli = self.generate_color_stimuli(batch_size=2)
            topic = stimuli[0].unsqueeze(0)

            # Speaker encodes the topic into its bottleneck code.
            speaker_encoding, _ = speaker(topic)

            # The hearer interprets the code with its decoder only: the code has
            # bottleneck_dim features, whereas calling hearer(...) would first run
            # its encoder, which expects input_dim features.
            hearer_reconstruction = hearer.decoder(speaker_encoding)

            loss = self.criterion(hearer_reconstruction, topic)

            # Clear both agents' gradients and backpropagate exactly once: the graph
            # is freed by backward(), and the speaker's grads must not carry over
            # from earlier rounds. Gradients reach the speaker's encoder through
            # speaker_encoding.
            hearer_opt.zero_grad()
            speaker_opt.zero_grad()
            loss.backward()
            hearer_opt.step()

            loss_value = loss.item()
            # If loss is high, update speaker as well
            if loss_value > speaker_update_threshold:
                speaker_opt.step()

            if log_every and round_idx % log_every == 0:
                print(f"Round {round_idx}, Loss: {loss_value}")

    def _collect_encodings(self, colors):
        """Encode ``colors`` with every agent and stack the codes into a 2-D array.

        Returns an array of shape ``(num_agents * code_dim, num_colors)``: rows
        are agent-major (all code dimensions of agent 0, then agent 1, ...),
        columns follow the order of ``colors``.
        """
        with torch.no_grad():
            per_agent = [agent(colors)[0].cpu().numpy().T for agent in self.agents]
        return np.concatenate(per_agent, axis=0)

    def analyze_categories(self):
        """Plot every agent's code for a 100-step gray ramp (R = G = B from 0 to 1)."""
        # NOTE(review): the ramp has 3 channels, so this assumes input_dim == 3.
        colors = torch.linspace(0, 1, 100).unsqueeze(1).repeat(1, 3)

        # 2-D (agents * code dims, colors) matrix; imshow rejects the 3-D stack.
        encodings = self._collect_encodings(colors)

        # Analyze the structure of these encodings to identify emerged categories
        # This is a placeholder and should be implemented based on the specific analysis needed
        plt.figure(figsize=(10, 5))
        plt.imshow(encodings, aspect='auto', cmap='viridis')
        plt.colorbar()
        plt.title("Color Encodings Across Agents")
        plt.xlabel("Color (from 0 to 1 in RGB space)")
        plt.ylabel("Agents and their encoding dimensions")
        plt.show()

# Usage: train a 10-agent population, then plot their codes.
# Guarded so importing this module does not start 100k rounds of training;
# in a notebook __name__ == "__main__", so the cell still runs as before.
if __name__ == "__main__":
    game = ColorCategoryGame(num_agents=10, input_dim=3, bottleneck_dim=5)
    game.play_game(num_rounds=100000)
    game.analyze_categories()

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x5 and 3x64)