Skip to content
Merged
4 changes: 2 additions & 2 deletions src/copilot-chat/config/copilot-chat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ llm-endpoint: "https://endpoint.openai.com"
llm-model: "gpt-4o"
llm-version: "2025-01-01-preview"
embedding-model: "text-embedding-ada-002"
valid-groups: "admin,superuser"
valid-vcs: "superuser"
environment: "prod"
kusto-user-assigned-client-id: ""
data-src-kusto-cluster-url: "https://placeholder.kusto.windows.net/"
data-src-kusto-database-name: "production"
data-src-kusto-database-name: "production"
4 changes: 2 additions & 2 deletions src/copilot-chat/deploy/copilot-chat-deployment.yaml.template
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ spec:
value: {{ cluster_cfg["copilot-chat"]["agent-host"] }}
- name: RESTSERVER_URL
value: {{ cluster_cfg["copilot-chat"]["rest-server"]["url"] }}
- name: COPILOT_VALID_GROUPS
value: {{ cluster_cfg["copilot-chat"]["valid-groups"] }}
- name: COPILOT_VALID_VCS
value: {{ cluster_cfg["copilot-chat"]["valid-vcs"] }}
- name: ENVIRONMENT
value: {{ cluster_cfg["copilot-chat"]["environment"] | default("prod") }}
- name: KUSTO_USER_ASSIGNED_CLIENT_ID
Expand Down
73 changes: 55 additions & 18 deletions src/copilot-chat/src/copilot_agent/utils/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,28 @@
"""Authentication Manager."""

import os
import requests
import urllib.parse

from datetime import datetime, timezone

from ..config import AGENT_MODE_LOCAL
from ..utils.logger import logger

class AuthenticationManager:
"""Manages authentication state, expiration, and revocation for users."""
def __init__(self, expiration_ms: int = 3600000):
self.authenticate_state = {} # username: {token, expires_at, groups}
self.expiration_ms = expiration_ms
self.restserver_url = os.getenv('RESTSERVER_URL', '')
valid_groups_env = os.getenv('COPILOT_VALID_GROUPS', 'admin,superuser')
self.valid_groups = [g.strip() for g in valid_groups_env.split(',') if g.strip()]
valid_vcs_env = os.getenv('COPILOT_VALID_VCS', 'admin,superuser')
self.valid_vcs = [g.strip() for g in valid_vcs_env.split(',') if g.strip()]

def authenticate(self, username: str, token: str) -> list:
def sanitize_username(self, username: str) -> str:
"""Sanitize the username by URL-encoding it to prevent path traversal or injection attacks."""
return urllib.parse.quote(username, safe='')

def authenticate(self, username: str, token: str):
"""
Authenticate a user with token and REST server URL.

Expand All @@ -27,28 +34,49 @@ def authenticate(self, username: str, token: str) -> list:
token (str): The authentication token provided by the user.

Returns:
list: A list of group names the user belongs to if authentication is successful.
Returns an empty list if authentication fails.
tuple: (admin, virtualCluster)
"""
if AGENT_MODE_LOCAL:
if username == "gooduser":
return ["admin"]
if username == "admin":
return True, []
if username == "gooduser" or username == "dev.ben":
return False, ["superuser"]
if username == "baduser":
return ["temp"]
return False, ["temp"]
# For any other username in local mode, return empty list
return False, []
else:
# TBD
# This function should implement the logic to verify the user's token against the REST server (self.restserver_url).
return []
try:
headers = {
'Authorization': f'Bearer {token}'
}
username_sanitized = self.sanitize_username(username)
response = requests.get(f'{self.restserver_url}/api/v2/users/{username_sanitized}', headers=headers, timeout=5)

if response.status_code == 200:
user_data = response.json()
# Extract groups from the response - adjust based on actual API response structure
is_admin = user_data.get('admin', False)
virtual_cluster = user_data.get('virtualCluster', [])
return is_admin, virtual_cluster
else:
logger.error(f"Authentication failed for user {username}: {response.status_code}")
return False, []
except Exception as e:
logger.error(f"Error during authentication for user {username}: {e}")
return False, []

def set_authenticate_state(self, username: str, token: str) -> None:
"""Set the authentication state for a user."""
"""Set the authentication state for a user, storing admin and virtualCluster info."""
expires_at = int(datetime.now(timezone.utc).timestamp() * 1000) + self.expiration_ms
groups = self.authenticate(username, token)
if groups:
is_admin, virtual_cluster = self.authenticate(username, token)
if is_admin is not None and virtual_cluster is not None:
self.authenticate_state[username] = {
'token': token,
'expires_at': expires_at,
'groups': groups
'is_admin': is_admin,
'virtual_cluster': virtual_cluster
}
else:
self.revoke(username)
Expand All @@ -61,14 +89,23 @@ def is_authenticated(self, username: str) -> bool:
if state['expires_at'] < now:
self.revoke(username)
return False
if "groups" not in state:
if "is_admin" not in state:
return False
if "groups" in state and not self.get_membership(state["groups"]):
if "virtual_cluster" not in state:
return False
return True
if "is_admin" in state and "virtual_cluster" in state:
if state["is_admin"]:
# validate pass condition one: user is an admin
return True
elif not state["is_admin"] and self.get_membership(state["virtual_cluster"]):
# validate pass condition two: user is not an admin, but it belongs to a valid virtualCluster
return True
else:
return False
return False

def get_membership(self, groups: list) -> bool:
return any(group in self.valid_groups for group in groups)
return any(group in self.valid_vcs for group in groups)

def revoke(self, username: str):
if username in self.authenticate_state:
Expand Down