Skip to content

Commit

Permalink
ready to close #9
Browse files Browse the repository at this point in the history
  • Loading branch information
Filipe Pina committed Aug 4, 2015
1 parent 87ef8c6 commit 4875df8
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 35 deletions.
15 changes: 14 additions & 1 deletion tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def list_commands(self):
TGCommandBase('read', self.read, 'read a note'),
TGCommandBase('savegroup', self.savegroup, 'save a group note'),
TGCommandBase('readgroup', self.readgroup, 'read a group note'),
TGCommandBase('prefixcmd', self.prefixcmd, 'prefix cmd', prefix=True)
TGCommandBase('prefixcmd', self.prefixcmd, 'prefix cmd', prefix=True, printable=False),
]

def echo_selective(self, bot, message, text):
Expand Down Expand Up @@ -124,6 +124,19 @@ def receive_message(self, text, sender=None, chat=None, reply_to_message_id=None

self.received_id += 1

def test_print_commands(self):
from cStringIO import StringIO
out = StringIO()
self.bot.print_commands(out=out)
self.assertEqual(out.getvalue(),'''\
read - read a note
echo - right back at ya
readgroup - read a group note
savegroup - save a group note
echo2 - right back at ya
save - save a note
''')

def test_reply(self):
self.receive_message('/echo test')
self.assertReplied(self.bot, 'test')
Expand Down
11 changes: 5 additions & 6 deletions tgbot/pluginbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@


class TGCommandBase(object):
def __init__(self, command, method, description='', prefix=False):
def __init__(self, command, method, description='', prefix=False, printable=True):
self.command = command
self.method = method
self.description = description
self.prefix = prefix
self.printable = printable

def printcommand(self):
print '%s - %s' % self.method, self.description
def __str__(self):
if self.printable:
return '%s - %s' % (self.command, self.description)


class TGPluginBase(object):
Expand All @@ -21,9 +23,6 @@ def __init__(self):
def list_commands(self):
'''
this method should return a list of TGCommandBase
Set command description to None (or '') to prevent that
command from being listed by TGBot.list_commands
'''
raise NotImplementedError('Abstract method')

Expand Down
67 changes: 39 additions & 28 deletions tgbot/tgbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .pluginbase import TGPluginBase, TGCommandBase
from playhouse.db_url import connect
import peewee
import sys


class TGBot(object):
Expand All @@ -11,6 +12,7 @@ def __init__(self, token, plugins=[], no_command=None, db_url=None):
self.tg = TelegramBot(token)
self._last_id = None
self.cmds = {}
self.pcmds = {}
self._no_cmd = no_command
self._msgs = {}
self._plugins = plugins
Expand All @@ -29,15 +31,19 @@ def __init__(self, token, plugins=[], no_command=None, db_url=None):
if not isinstance(cmd, TGCommandBase):
raise NotImplementedError('%s does not subclass tgbot.TGCommandBase' % type(cmd).__name__)

if cmd in self.cmds:
if cmd in self.cmds or cmd in self.pcmds:
raise Exception(
'Duplicate command %s: both in %s and %s' % (
cmd.command,
type(p).__name__,
self.cmds[cmd.command].description,
self.cmds.get(cmd.command) or self.pcmds.get(cmd.command),
)
)
self.cmds[cmd.command] = cmd

if cmd.prefix:
self.pcmds[cmd.command] = cmd
else:
self.cmds[cmd.command] = cmd

if db_url is None:
self.db = connect('sqlite:///:memory:')
Expand Down Expand Up @@ -95,9 +101,20 @@ def process_update(self, update): # noqa not complex at all!
pass # ignore, already exists

if message.text is not None and message.text.startswith('/'):
text = message.text.replace("@" + self.tg.username, '', 1)
for p in self._plugins:
self.process(p, message, text)
spl = message.text.find(' ')

if spl < 0:
cmd = message.text[1:]
text = ''
else:
cmd = message.text[1:spl]
text = message.text[spl+1:]

spl = cmd.find('@')
if spl > -1:
cmd = cmd[:spl]

self.process(message, cmd, text)
else:
was_expected = False
for p in self._plugins:
Expand Down Expand Up @@ -136,31 +153,25 @@ def run_web(self, hook_url, **kwargs):
run_server(self, **kwargs)

def list_commands(self):
out = []
for ck in self.cmds:
if self.cmds[ck].description:
out.append((ck, self.cmds[ck].description))
return out
return self.cmds.values() + self.pcmds.values()

def print_commands(self):
def print_commands(self, out=sys.stdout):
'''
utility method to print commands and descriptions
for @BotFather
utility method to print commands
and descriptions for @BotFather
'''
cmds = self.list_commands()
for ck in cmds:
print '%s - %s' % ck

def process(self, plugin, message, text):
for cmd in plugin.list_commands():
if text.startswith(cmd.command, 1):
if len(text) == (len(cmd.command) + 1):
cmd.method(self, message, '')
break
spl = text.find(' ')
if spl == (len(cmd.command) + 1):
cmd.method(self, message, text[spl + 1:])
break
if cmd.prefix:
cmd.method(self, message, text[len(cmd.command) + 1:])
if ck.printable:
out.write('%s\n' % ck)

def process(self, message, cmd, text):
if cmd in self.cmds:
self.cmds[cmd].method(self, message, text)
elif cmd in self.pcmds:
self.pcmds[cmd].method(self, message, text)
else:
for pcmd in self.pcmds:
if cmd.startswith(pcmd):
self.pcmds[pcmd].method(self, message, cmd[len(pcmd):] + text)
break

0 comments on commit 4875df8

Please sign in to comment.