In [None]:
%load_ext autoreload
%autoreload 2
# default_exp plugin.authenticated

In [None]:
#export
from pymemri.data.schema import Item, CVUStoredDefinition, Account
from pymemri.plugin.stateful import StatefulPlugin
import abc
import logging
from time import time, sleep

In [None]:
#export
# Seconds to wait for the user to complete a requested action before giving up.
RUN_USER_ACTION_TIMEOUT: int = 120
# Seconds to sleep between consecutive polls of the run state.
RUN_STATE_POLLING_INTERVAL: float = 0.6

## Authentication methods for plugins

In [None]:
# export
class AuthenticatedPlugin(StatefulPlugin):
    """Base class for plugins that need a 3rd-party account before running.

    Subclasses implement `start_auth`. The helpers below drive a
    "user action needed" flow:

    1. Mark the run as requiring user action.
    2. Publish a view (and optional OAuth URL) on the run.
    3. Poll until the client reports the action as completed.
    4. Read the account attached to the plugin's persistent state.
    """

    def __init__(self, runId=None, **kwargs):
        super().__init__(runId=runId, **kwargs)

    @abc.abstractmethod
    def start_auth(self, client):
        """Authenticate the plugin against its external service.

        Must be implemented by subclasses. Note that `abc.abstractmethod` is
        only enforced at instantiation if the metaclass is `abc.ABCMeta`;
        whether `StatefulPlugin` provides that is not visible here.
        """
        raise NotImplementedError()

    def get_account_from_plugin(self, client, pluginName=None):
        """Return the account attached to the plugin's persistent state.

        Args:
            client: pod client used to look up the state.
            pluginName: optional name passed through to `get_state`.

        Returns:
            The attached account, or None when there is no persistent state
            yet or the state has no account.
        """
        persistent_state = self.get_state(client, pluginName)
        # `get_state` can yield a falsy value when nothing has been persisted
        # yet; previously this crashed with AttributeError on `.get_account()`.
        if not persistent_state:
            return None
        return persistent_state.get_account()

    def ask_user_for_accounts(self, client, view, oauth_url=None, timeout=None, poll_interval=None):
        """Ask the user (via the client) for account details and wait for them.

        Marks the run as needing user action and publishes `oauth_url` as the
        `oAuthUrl` run variable and `view` as the run view. It then polls
        `is_action_completed` until it is truthy or `timeout` seconds elapse.

        Args:
            client: pod client used for all reads/writes.
            view: name of the view the client should show.
            oauth_url: optional OAuth URL exposed to the client as `oAuthUrl`.
            timeout: seconds to wait; defaults to RUN_USER_ACTION_TIMEOUT.
            poll_interval: seconds between polls; defaults to
                RUN_STATE_POLLING_INTERVAL.

        Returns:
            The account attached to the plugin once the action completed
            (may be None if the client did not attach one).

        Raises:
            Exception: "PluginFlow: User input timeout" if the action is not
                completed in time.
        """
        from time import monotonic

        # Resolve defaults at call time so runtime changes to the module-level
        # constants are still honoured.
        if timeout is None:
            timeout = RUN_USER_ACTION_TIMEOUT
        if poll_interval is None:
            poll_interval = RUN_STATE_POLLING_INTERVAL

        # start userActionNeeded flow
        self.action_required(client)
        self.set_run_vars(client, {'oAuthUrl': oauth_url})
        self.set_run_view(client, view)

        # monotonic() is unaffected by wall-clock adjustments, unlike time().
        deadline = monotonic() + timeout
        while monotonic() < deadline:
            sleep(poll_interval)
            if self.is_action_completed(client):
                # The client has now attached the account as an edge to the plugin.
                return self.get_account_from_plugin(client)

        raise Exception("PluginFlow: User input timeout")

    def set_account_vars(self, client, vars_dictionary):
        """Set properties on the plugin's account, creating it if needed.

        If an account is already attached, each key/value in
        `vars_dictionary` is set as an attribute and the account is updated.
        Otherwise a new `Account` is built from `vars_dictionary` and created.
        It is then linked to the item at `self.persistenceId` through an
        'account' edge.
        """
        account = self.get_account_from_plugin(client)
        if account:
            for key, value in vars_dictionary.items():
                setattr(account, key, value)
            account.update(client)
        else:
            account = Account(**vars_dictionary)
            client.create(account)
            # Link the new account to the plugin's persisted item.
            plugin = client.get(self.persistenceId)
            plugin.add_edge('account', account)
            plugin.update(client)

    def add_to_schema(self, client):
        """Register the parent schema plus the `Account` item type."""
        super().add_to_schema(client)
        # Template Account with empty-string properties; presumably lets the pod
        # infer each property's type — confirm against PodClient.add_to_schema.
        client.add_to_schema(Account(identifier="", secret="", code="", refreshToken="",
                                     service="", handle="", displayName="", avatarUrl="",
                                     externalId="", errorMessage=""))


## Building a plugin with authentication

In [None]:
NUM_LOGIN_TRIES = 3
SERVICE_NAME = "example service"

class MyAuthPlugin(AuthenticatedPlugin):
    """Example plugin implementing `start_auth` with a bounded retry loop."""

    def __init__(self, runId=None, **kwargs):
        super().__init__(runId=runId, **kwargs)

    def start_auth(self, client):
        """Authenticate, asking the user up to NUM_LOGIN_TRIES times.

        Returns:
            True if an error-free account is already attached or the user
            enters valid credentials; False after NUM_LOGIN_TRIES failures.
        """
        account = self.get_account_from_plugin(client)
        if account and not account.errorMessage:
            return True

        for _ in range(NUM_LOGIN_TRIES):
            # NOTE(review): the demo cell persists views named "login-view" and
            # "other-view", not "auth-view" — confirm which name the client expects.
            account = self.ask_user_for_accounts(client, "auth-view", oauth_url="")
            # test user input here, normally via a service e.g. `self.service.login(account.identifier, account.secret)`
            if account and account.identifier == "username" and account.secret == "password":
                # Clear an error left by an earlier failed attempt. Otherwise the
                # early-return check above fails on every later run and the user
                # is prompted again despite valid credentials.
                if account.errorMessage:
                    self.set_account_vars(client, {'errorMessage': ''})
                return True
            self.set_account_vars(client, {'errorMessage': 'Incorrect credentials'})

        return False

    def run(self):
        """Placeholder plugin body for the example."""
        print("Running plugin")

    def add_to_schema(self, client):
        """Register the base authenticated-plugin schema (add plugin-specific schemas here)."""
        print("Adding schema")
        super().add_to_schema(client)


In [None]:
from pymemri.pod.client import PodClient
from pymemri.data.schema import CVUStoredDefinition
from pymemri.plugin.pluginbase import PluginRun, register_base_schemas
from pymemri.plugin.stateful import RUN_STARTED

# prepare: connect to a pod using PodClient's default arguments
client = PodClient()

# deploy plugin or get the deployed state
auth_plugin = MyAuthPlugin()
auth_plugin.add_to_schema(client)
# NOTE(review): pluginName is "AuthenticatedPlugin" (the abstract base) although the
# instance is a MyAuthPlugin — confirm this is intended for the demo.
auth_plugin.init_run(client, containerImage="authenticated_plugin", pluginModule="pymemri.plugin.authenticated", pluginName="AuthenticatedPlugin")
persistence = auth_plugin.get_state(client, "myAuthPlugin")
if not persistence:
    views = [CVUStoredDefinition(name="login-view"), CVUStoredDefinition(name="other-view")]
    persistence = auth_plugin.persist(client, "myAuthPlugin", views=views, account=None)

# test the run states
auth_plugin.started(client)
assert auth_plugin.get_run_state_str(client) == RUN_STARTED

# test empty persistent account
# NOTE(review): if the pod already holds a "myAuthPlugin" state from a previous run
# (with the account set below), this assertion fails — the cell is not idempotent.
state = auth_plugin.get_state(client)
assert state.get_account() is None

# set an account attached to the persistent state
service_account = Account(service="3rd party", identifier="me", secret="mypassword")
service_account.update(client)
state = auth_plugin.get_state(client)
state.set_account(client, service_account)
assert state.get_account().identifier == "me"
