In [5]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt  # For plotting
import gym  # <-- ADD THIS IMPORT

# Load dataset
# Each row supplies a "|"-delimited document string plus the index of its target token.
data = pd.read_csv("rl_documents.csv")
documents = list(map(lambda text: text.split("|"), data["document"]))
target_indices = data["target_index"].to_list()

class DocumentEnv:
    """Single-step environment in which an agent picks one token of a document.

    Every episode presents one tokenized document. The action is a token
    index, and the episode terminates after that single choice.

    Attributes:
        documents: sequence of tokenized documents (each a sequence of tokens).
        targets: for each document, the index of the correct token.
        current_doc: index into ``documents`` of the active document.
        action_space: ``gym.spaces.Discrete`` with one action per token of
            the active document; rebuilt on every ``reset``.
    """

    def __init__(self, documents, targets):
        """Store the data and start on document 0."""
        self.documents = documents
        self.targets = targets
        self.current_doc = 0
        self.action_space = self._make_action_space()  # Requires gym

    def _make_action_space(self):
        """Build a discrete action space sized to the active document."""
        return gym.spaces.Discrete(len(self.documents[self.current_doc]))

    def reset(self):
        """Select a random document and return its tokens.

        Returns:
            The token list of the newly selected document.
        """
        self.current_doc = np.random.randint(0, len(self.documents))
        # Documents may differ in length; without this refresh the space
        # created in __init__ for document 0 would describe the wrong document.
        self.action_space = self._make_action_space()
        return self.documents[self.current_doc]

    def step(self, action):
        """Score the chosen token index for the active document.

        Args:
            action: token index selected by the agent.

        Returns:
            Tuple ``(observation, reward, done, info)``: the active document,
            ``1`` if ``action`` equals the document's target index else ``-1``,
            ``True`` (each episode is exactly one step), and an empty dict.
        """
        correct_idx = self.targets[self.current_doc]
        reward = 1 if action == correct_idx else -1
        done = True
        return self.documents[self.current_doc], reward, done, {}

# Initialize environment
env = DocumentEnv(documents=documents, targets=target_indices)
# One Q-table row per document; the column count matches the longest document.
q_table = np.zeros(shape=(len(documents), max(map(len, documents))))

# Train the agent
# Greedy tabular update: each iteration is one single-step episode whose reward
# is accumulated into the Q-value of the (document, action) pair that was tried.
for episode in range(100):
    state = env.reset()
    # Take the row index from the environment instead of documents.index(state):
    # list.index returns the FIRST equal document, so duplicate documents would
    # all be credited to one row, and the lookup is a linear scan per episode.
    state_id = env.current_doc
    action = np.argmax(q_table[state_id])
    _, reward, _, _ = env.step(action)
    q_table[state_id, action] += reward

# Test
# Greedy read-out for the first document: its highest-valued action is the predicted token index.
test_doc = documents[0]
predicted_idx = q_table[0].argmax()
print(f"Document: {test_doc}")
print(f"Predicted target: {test_doc[predicted_idx]}")

Document: ['Invoice', 'ID', 'INV-768', 'Date', '2023-10-05', 'Total', '$500']
Predicted target: INV-768


In [9]:
# Test 2
test_doc = documents[1]
# Read the Q-values of the document being displayed (row 1). Row 0 would report
# document 0's learned action against document 1's tokens.
predicted_idx = np.argmax(q_table[1])
print(f"Document: {test_doc}")
print(f"Predicted target: {test_doc[predicted_idx]}")

Document: ['Receipt', 'ID', 'RC-2023', 'Shop', 'CoffeeCo', 'Amount', '$8.99']
Predicted target: RC-2023


In [12]:
# Report how many documents were loaded into `documents`.
total_docs = len(documents)
print(f"📑 Total Documents Processed: {total_docs}")

📑 Total Documents Processed: 5


In [14]:
# Count frequency of each target index:
# one entry per distinct value in the "target_index" column with its number of occurrences.
index_counts = data["target_index"].value_counts()
print("\n🎯 Target Index Frequencies:")
print(index_counts)


🎯 Target Index Frequencies:
target_index
2    2
6    2
5    1
Name: count, dtype: int64
