/
trim.py
86 lines (80 loc) · 3.06 KB
/
trim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Tuple
from refinery.units import Arg, Unit
class trim(Unit):
    """
    Removes byte sequences at beginning and end of input data.
    """

    def __init__(
        self, *junk: Arg(help='Binary strings to be removed, default are all whitespace characters.'),
        unpad: Arg.Switch('-u', help='Also trim partial occurrences of the junk string.') = False,
        left: Arg.Switch('-r', '--right-only', group='SIDE', help='Do not trim left.') = True,
        right: Arg.Switch('-l', '--left-only', group='SIDE', help='Do not trim right.') = True,
        nocase: Arg.Switch('-i', help='Ignore capitalization for alphabetic characters.') = False,
    ):
        super().__init__(junk=junk, left=left, right=right, unpad=unpad, nocase=nocase)

    def _trimfast(self, view: memoryview, *junks: bytes, right: bool = False) -> int:
        """
        Return the number of leading bytes of ``view`` that are made up of
        repetitions of the given junk strings.

        Trimming the right side is done by the caller, which passes a reversed
        buffer and reversed junk strings together with ``right=True``. The flag
        only changes where partial occurrences are recognized when the ``unpad``
        option is set.
        """
        done = False
        pos = 0
        while not done:
            done = True
            for junk in junks:
                size = len(junk)
                if not size:
                    # An empty junk string "matches" without ever advancing; the
                    # doubling loop below would never terminate on it.
                    continue
                temp = junk
                if right and self.args.unpad:
                    # Right side: the outermost occurrence may be partial, which in
                    # the reversed buffer shows up as a suffix of the reversed junk.
                    for k in range(size):
                        n = size - k
                        if view[pos:pos + n] == junk[k:]:
                            pos += n
                            done = False
                            break
                if view[pos:pos + size] == temp:
                    # Keep doubling the pattern while the next chunk still matches,
                    # then consume the run greedily with halving chunk sizes. This
                    # counts the repetitions in a logarithmic number of comparisons.
                    m = len(temp)
                    while True:
                        mm = m << 1
                        if view[pos + m:pos + mm] != temp:
                            break
                        temp += temp
                        m = mm
                    temp = memoryview(temp)
                    while m >= size:
                        if view[pos:pos + m] == temp[:m]:
                            done = False
                            pos += m
                        m //= 2
                if right or not self.args.unpad:
                    continue
                # Left side with unpad: consume the longest prefix of the junk
                # string that directly follows the full repetitions.
                while size > 0:
                    if view[pos:pos + size] == temp[:size]:
                        done = False
                        pos += size
                        break
                    size -= 1
        return pos

    def process(self, data: bytearray):
        """
        Strip the configured junk strings from the start and/or end of ``data``
        and return a memoryview of the remaining bytes.
        """
        junk = list(self.args.junk)
        if not junk:
            import string
            space = string.whitespace.encode('ascii')
            # One single-byte junk string for every character in string.whitespace.
            junk = [space[k:k + 1] for k in range(len(space))]
        lpos = 0
        rpos = 0
        if self.args.nocase:
            work = data.lower()
            junk = [j.lower() for j in junk]
        else:
            work = data
        if self.args.left:
            lpos = self._trimfast(memoryview(work), *junk)
        if self.args.right:
            # Trim the right side by scanning the reversed buffer with reversed junk.
            # Only the part that survived left trimming is scanned, so both sides
            # can never claim the same bytes.
            junk = [bytes(reversed(j)) for j in junk]
            work.reverse()
            try:
                rpos = self._trimfast(memoryview(work)[:len(work) - lpos], *junk, right=True)
            finally:
                # work may be the caller's buffer itself; always restore its order.
                work.reverse()
        view = memoryview(data)
        if lpos:
            view = view[lpos:]
        if rpos:
            view = view[:-rpos]
        return view