In [3]:
import gymnasium as gym
import numpy as np
import permuted_mnist 

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from gymnasium.wrappers import TimeLimit

class MNISTLogisticRegression(nn.Module):
    def __init__(self, input_size=784):  # 28x28 = 784 pixels per image
        super().__init__()
        self.flatten = nn.Flatten()
        # Logistic regression is effectively a linear layer followed by softmax
        self.logistic = nn.Sequential(
            nn.Linear(input_size, 10),  # 10 classes (digits 0-9)
            nn.LogSoftmax(dim=1)  # Log softmax for numerical stability
        )
        
    def forward(self, x):
        x = self.flatten(x)
        return self.logistic(x)

class Agent:
    def __init__(self, env, player_name=None):
        self.env = env
        # Unwrap TimeLimit to get to the base environment
        while isinstance(self.env, TimeLimit):
            self.env = self.env.env
            
        # Get dimensions from the environment
        self.train_size = 70000
        self.test_size = 12000
        self.img_size = 28      # MNIST image size
        
        # Initialize shapes
        self.train_images_shape = (self.train_size, self.img_size, self.img_size)
        self.train_labels_shape = (self.train_size,)
        self.test_images_shape = (self.test_size, self.img_size, self.img_size)
        
        # Initialize logistic regression model
        self.model = MNISTLogisticRegression()
        self.criterion = nn.NLLLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        
    def train(self, X_train, y_tain):
        # Extract and process data
        # train_images, train_labels, test_images = self._extract_data(observation)
        
        # Convert to PyTorch tensors
        train_images = torch.FloatTensor(X_train)
        train_labels = torch.LongTensor(y_tain)
        
        # Train the model
        self.model.train()
        for _ in range(5):  # Multiple epochs per step
            self.optimizer.zero_grad()
            output = self.model(train_images)
            loss = self.criterion(output, train_labels)
            loss.backward()
            self.optimizer.step()

    def predict(self, X_test):
        test_images = torch.FloatTensor(X_test)
        # Predict test labels
        self.model.eval()
        with torch.no_grad():
            test_output = self.model(test_images)
            predictions = test_output.argmax(dim=1).numpy()
            
        return predictions

In [6]:

# Create the environment
env = gym.make('PermutedMNIST-v0', render_mode="rgb_array")

terminated = False
truncated = False
reward = 0

agent = Agent(env)
step = 0

observation, info = env.reset()

while not (terminated or truncated):
    # Sample random action from action space

    agent.train(observation['X_train'], observation['y_train'])

    #Y_pred = agent.predict(observation['X_test'])
    Y_pred = np.random.randint(0, 10, len(observation['X_test']))
    # Step the environment
    observation, reward, terminated, truncated, info = env.step(Y_pred)
    print(f"r:{reward}")
    step +=1
env.close()

init permuted_mnist env
loading mnist data
train_images loaded
train_labels loaded
test_images loaded
creating observation space
observation space created
resetting env
permutations created
datasets shuffled
labels permuted
pixels permuted
pixels permuted
getting obs


  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")


stepping env
getting obs


  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")


r:0.1001
stepping env
getting obs
r:0.0955
stepping env
getting obs
r:0.1013
stepping env
getting obs
r:0.1029
stepping env
getting obs
r:0.1026
stepping env
getting obs
r:0.1039
stepping env
getting obs
r:0.1055
stepping env
getting obs
r:0.0965
stepping env
getting obs
r:0.0956
stepping env
getting obs
r:0.1059
