diff --git a/src/copilot-chat/config/copilot-chat.yaml b/src/copilot-chat/config/copilot-chat.yaml index 79a464d0..0b578ae2 100644 --- a/src/copilot-chat/config/copilot-chat.yaml +++ b/src/copilot-chat/config/copilot-chat.yaml @@ -32,8 +32,8 @@ llm-endpoint: "https://endpoint.openai.com" llm-model: "gpt-4o" llm-version: "2025-01-01-preview" embedding-model: "text-embedding-ada-002" -valid-groups: "admin,superuser" +valid-vcs: "superuser" environment: "prod" kusto-user-assigned-client-id: "" data-src-kusto-cluster-url: "https://placeholder.kusto.windows.net/" -data-src-kusto-database-name: "production" \ No newline at end of file +data-src-kusto-database-name: "production" diff --git a/src/copilot-chat/deploy/copilot-chat-deployment.yaml.template b/src/copilot-chat/deploy/copilot-chat-deployment.yaml.template index 809db270..6922acb8 100644 --- a/src/copilot-chat/deploy/copilot-chat-deployment.yaml.template +++ b/src/copilot-chat/deploy/copilot-chat-deployment.yaml.template @@ -75,8 +75,8 @@ spec: value: {{ cluster_cfg["copilot-chat"]["agent-host"] }} - name: RESTSERVER_URL value: {{ cluster_cfg["copilot-chat"]["rest-server"]["url"] }} - - name: COPILOT_VALID_GROUPS - value: {{ cluster_cfg["copilot-chat"]["valid-groups"] }} + - name: COPILOT_VALID_VCS + value: {{ cluster_cfg["copilot-chat"]["valid-vcs"] }} - name: ENVIRONMENT value: {{ cluster_cfg["copilot-chat"]["environment"] | default("prod") }} - name: KUSTO_USER_ASSIGNED_CLIENT_ID diff --git a/src/copilot-chat/src/copilot_agent/utils/authentication.py b/src/copilot-chat/src/copilot_agent/utils/authentication.py index 20c47612..784158f9 100644 --- a/src/copilot-chat/src/copilot_agent/utils/authentication.py +++ b/src/copilot-chat/src/copilot_agent/utils/authentication.py @@ -4,10 +4,13 @@ """Authentication Manager.""" import os +import requests +import urllib.parse from datetime import datetime, timezone from ..config import AGENT_MODE_LOCAL +from ..utils.logger import logger class AuthenticationManager: """Manages authentication state, expiration, and revocation for users.""" @@ -15,10 +18,14 @@ def __init__(self, expiration_ms: int = 3600000): self.authenticate_state = {} # username: {token, expires_at, groups} self.expiration_ms = expiration_ms self.restserver_url = os.getenv('RESTSERVER_URL', '') - valid_groups_env = os.getenv('COPILOT_VALID_GROUPS', 'admin,superuser') - self.valid_groups = [g.strip() for g in valid_groups_env.split(',') if g.strip()] + valid_vcs_env = os.getenv('COPILOT_VALID_VCS', 'admin,superuser') + self.valid_vcs = [g.strip() for g in valid_vcs_env.split(',') if g.strip()] - def authenticate(self, username: str, token: str) -> list: + def sanitize_username(self, username: str) -> str: + """Sanitize the username by URL-encoding it to prevent path traversal or injection attacks.""" + return urllib.parse.quote(username, safe='') + + def authenticate(self, username: str, token: str): """ Authenticate a user with token and REST server URL. @@ -27,28 +34,49 @@ def authenticate(self, username: str, token: str) -> list: token (str): The authentication token provided by the user. Returns: - list: A list of group names the user belongs to if authentication is successful. - Returns an empty list if authentication fails. + tuple: (admin, virtualCluster) """ if AGENT_MODE_LOCAL: - if username == "gooduser": - return ["admin"] + if username == "admin": + return True, [] + if username == "gooduser" or username == "dev.ben": + return False, ["superuser"] if username == "baduser": - return ["temp"] + return False, ["temp"] + # For any other username in local mode, return empty list + return False, [] else: - # TBD # This function should implement the logic to verify the user's token against the REST server (self.restserver_url). - return [] + try: + headers = { + 'Authorization': f'Bearer {token}' + } + username_sanitized = self.sanitize_username(username) + response = requests.get(f'{self.restserver_url}/api/v2/users/{username_sanitized}', headers=headers, timeout=5) + + if response.status_code == 200: + user_data = response.json() + # Extract groups from the response - adjust based on actual API response structure + is_admin = user_data.get('admin', False) + virtual_cluster = user_data.get('virtualCluster', []) + return is_admin, virtual_cluster + else: + logger.error(f"Authentication failed for user {username}: {response.status_code}") + return False, [] + except Exception as e: + logger.error(f"Error during authentication for user {username}: {e}") + return False, [] def set_authenticate_state(self, username: str, token: str) -> None: - """Set the authentication state for a user.""" + """Set the authentication state for a user, storing admin and virtualCluster info.""" expires_at = int(datetime.now(timezone.utc).timestamp() * 1000) + self.expiration_ms - groups = self.authenticate(username, token) - if groups: + is_admin, virtual_cluster = self.authenticate(username, token) + if is_admin is not None and virtual_cluster is not None: self.authenticate_state[username] = { 'token': token, 'expires_at': expires_at, - 'groups': groups + 'is_admin': is_admin, + 'virtual_cluster': virtual_cluster } else: self.revoke(username) @@ -61,14 +89,23 @@ def is_authenticated(self, username: str) -> bool: if state['expires_at'] < now: self.revoke(username) return False - if "groups" not in state: + if "is_admin" not in state: return False - if "groups" in state and not self.get_membership(state["groups"]): + if "virtual_cluster" not in state: return False - return True + if "is_admin" in state and "virtual_cluster" in state: + if state["is_admin"]: + # validate pass condition one: user is an admin + return True + elif not state["is_admin"] and self.get_membership(state["virtual_cluster"]): + # validate pass condition two: user is not an admin, but it belongs to a valid virtualCluster + return True + else: + return False + return False def get_membership(self, groups: list) -> bool: - return any(group in self.valid_groups for group in groups) + return any(group in self.valid_vcs for group in groups) def revoke(self, username: str): if username in self.authenticate_state: