In [2]:
# Generating an access JWT token for the API access

In [1]:
import os
import sys
import os
from datetime import timedelta, datetime, timezone

sys.path.append("..")

os.environ["CONFIG_PATH"] = "../config.yaml"
from config import Config, get_config
import jwt

In [2]:
cfg = get_config()
jwt_cfg = cfg.get("jwt")
print(jwt_cfg)

{'secret_key': '09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7', 'expiration': 2592000, 'algorithm': 'HS256'}


In [14]:
def create_access_token(data: dict, secret_key:str, algorithm: str, expires_delta: timedelta = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
    return encoded_jwt

SLIDING_EXPIRATION_THRESHOLD = 5 # 5 minutes

def verify_and_extend_token(token:str, secret_key: str, algorithm: str):
    """
    Verify the token and extend it if it's close to expiry
    Args: 
        token: The JWT token to verify
        secret_key: The secret key used to sign the token
        algorithm: The algorithm used to sign the token
    Returns:
        The original token if it's not close to expiry, or a new token if it is
    """
    try:
        payload = jwt.decode(token, secret_key, algorithms=[algorithm])
        exp = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
        remaining_time = (exp - datetime.now(timezone.utc)).total_seconds() / 60  # In minutes

        if remaining_time < SLIDING_EXPIRATION_THRESHOLD:
            print("Token near expiration, extending...")
            return create_access_token({"sub": payload["sub"]})
        
        return token  # Return the same token if it's not close to expiry
    except jwt.ExpiredSignatureError:
        raise ValueError("Token expired")
    except jwt.InvalidTokenError:
        raise ValueError("Invalid token")

In [9]:
acc_token = create_access_token({"sub": "test"}, jwt_cfg.get("secret_key"), jwt_cfg.get("algorithm"), expires_delta=timedelta(seconds=jwt_cfg.get("expiration", 900)))

In [10]:
acc_token

'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0IiwiZXhwIjoxNzQxMjk1NjIzfQ.5E4ts6UFZd0crNucisZZcI8oebS2M8AUaXAKairSHQA'

In [15]:
refreshed_token = verify_and_extend_token(acc_token, jwt_cfg.get("secret_key"), jwt_cfg.get("algorithm"))

In [16]:
refreshed_token

'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0IiwiZXhwIjoxNzQxMjk1NjIzfQ.5E4ts6UFZd0crNucisZZcI8oebS2M8AUaXAKairSHQA'