/
zzf.go
137 lines (121 loc) · 3.86 KB
/
zzf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package security
import (
"crypto/md5"
"errors"
"fmt"
"reflect"
"unsafe"
)
// ThirdPartCipher describes one algorithm offered by an externally loaded
// cipher library, together with the key used for it. Values are created by
// NewThirdPartCipher, which also fills in the fields discovered from the
// library (encryptName, blockSize, khSize).
type ThirdPartCipher struct {
encryptType int // ID of the external encryption algorithm
encryptName string // name of the external encryption algorithm (set by getInfo)
hashType int // stored by NewThirdPartCipher; not read anywhere in this file
key []byte // key bytes handed to the library's encrypt/decrypt init calls
cipherCount int // number of external encryption algorithms; -1 means "not queried yet"
//innerId int // internal ID of the external encryption algorithm
blockSize int // cipher block size (set by getInfo)
khSize int // key/hash size (set by getInfo)
}
// NewThirdPartCipher builds a ThirdPartCipher for the algorithm identified by
// encryptType.
//
// It first initialises the external cipher library from cipherPath, then
// queries how many algorithms the library exposes and looks up the one whose
// ID equals encryptType. The first error encountered is returned unchanged,
// alongside the cipher value as populated up to that point.
func NewThirdPartCipher(encryptType int, key []byte, cipherPath string, hashType int) (ThirdPartCipher, error) {
	tpc := ThirdPartCipher{
		cipherCount: -1, // -1 marks the count as not fetched yet
		hashType:    hashType,
		key:         key,
		encryptType: encryptType,
	}

	var err error
	if err = initThirdPartCipher(cipherPath); err != nil {
		return tpc, err
	}

	// Prime the cached algorithm count before scanning the algorithm table.
	tpc.getCount()

	err = tpc.getInfo()
	if err != nil {
		return tpc, err
	}
	return tpc, nil
}
// getCount returns the number of algorithms reported by the cipher library.
// The library is asked only while the cached value is still the -1 sentinel;
// afterwards the cached value is returned.
func (tpc *ThirdPartCipher) getCount() int {
	if tpc.cipherCount != -1 {
		return tpc.cipherCount
	}
	tpc.cipherCount = cipherGetCount()
	return tpc.cipherCount
}
// getInfo walks the library's algorithm table (indices 1..getCount()) looking
// for the entry whose ID equals tpc.encryptType. On a match it records the
// entry's block size, key/hash size and name on tpc and returns nil; if no
// entry matches it returns an error naming the missing ID.
func (tpc *ThirdPartCipher) getInfo() error {
	var (
		cipherID  int
		kind      int // receives the library's "type" output; not used afterwards
		blockSize int
		khSize    int
	)

	// The address of namePtr is handed to the library; afterwards namePtr is
	// treated as the start of a NUL-terminated byte string.
	namePtr := new(uint16)

	total := tpc.getCount()
	for idx := 1; idx <= total; idx++ {
		cipherGetInfo(
			uintptr(idx),
			uintptr(unsafe.Pointer(&cipherID)),
			uintptr(unsafe.Pointer(&namePtr)),
			uintptr(unsafe.Pointer(&kind)),
			uintptr(unsafe.Pointer(&blockSize)),
			uintptr(unsafe.Pointer(&khSize)),
		)
		if cipherID != tpc.encryptType {
			continue
		}

		tpc.blockSize = blockSize
		tpc.khSize = khSize
		tpc.encryptName = string(uintptr2bytes(uintptr(unsafe.Pointer(namePtr))))
		return nil
	}

	return fmt.Errorf("ThirdPartyCipher: cipher id:%d not found", tpc.encryptType)
}
// Encrypt encrypts plaintext with the external cipher selected by
// tpc.encryptType and tpc.key and returns the ciphertext.
//
// When genDigest is true, the MD5 digest of the plaintext is appended after
// the ciphertext bytes, so the result is ciphertext || md5(plaintext).
//
// The function indexes element 0 of both tpc.key and plaintext, so both must
// be non-empty.
func (tpc ThirdPartCipher) Encrypt(plaintext []byte, genDigest bool) []byte {
	var para uintptr
	id := uintptr(tpc.encryptType)

	cipherEncryptInit(id, uintptr(unsafe.Pointer(&tpc.key[0])), uintptr(len(tpc.key)), para)

	// Ask the library how large the output buffer has to be for this input.
	size := cipherGetCipherTextSize(id, para, uintptr(len(plaintext)))
	out := make([]byte, size)

	written := cipherEncrypt(
		id, para,
		uintptr(unsafe.Pointer(&plaintext[0])), uintptr(len(plaintext)),
		uintptr(unsafe.Pointer(&out[0])), uintptr(len(out)),
	)
	out = out[:written]
	cipherClean(id, para)

	if !genDigest {
		return out
	}

	// MD5 digest of the plaintext goes right behind the ciphertext.
	sum := md5.Sum(plaintext)
	sealed := make([]byte, len(out)+len(sum))
	n := copy(sealed, out)
	copy(sealed[n:], sum[:])
	return sealed
}
// Decrypt decrypts ciphertext with the external cipher and returns the
// plaintext.
//
// When checkDigest is true, the last MD5_DIGEST_SIZE bytes of ciphertext are
// taken as the MD5 digest of the plaintext: only the bytes before them are
// decrypted, and the result is verified against that digest.
//
// An error is returned when
//   - there is no payload to decrypt (empty input, or input no longer than
//     the digest when checkDigest is true), or
//   - the digest does not match the decrypted plaintext.
//
// Previously the first case panicked (negative slice bound, or indexing
// element 0 of an empty slice in decrypt) instead of reporting an error.
func (tpc ThirdPartCipher) Decrypt(ciphertext []byte, checkDigest bool) ([]byte, error) {
	if !checkDigest {
		if len(ciphertext) == 0 {
			return nil, errors.New("decrypt failed: ciphertext is empty")
		}
		return tpc.decrypt(ciphertext), nil
	}

	// Need at least one payload byte in front of the trailing digest.
	if len(ciphertext) <= MD5_DIGEST_SIZE {
		return nil, errors.New("decrypt failed: ciphertext too short to hold payload and digest")
	}

	split := len(ciphertext) - MD5_DIGEST_SIZE
	digest := ciphertext[split:]
	plaintext := tpc.decrypt(ciphertext[:split])

	sum := md5.Sum(plaintext)
	if !reflect.DeepEqual(sum[:], digest) {
		// Message text kept verbatim: callers may match on it.
		return nil, errors.New("Decrypt failed/Digest not match\n")
	}
	return plaintext, nil
}
// decrypt runs the external cipher's decryption over ciphertext and returns
// the plaintext, truncated to the length reported by the library.
//
// The output buffer is allocated with the same length as the input. The
// function indexes element 0 of tpc.key and ciphertext, so both must be
// non-empty.
func (tpc ThirdPartCipher) decrypt(ciphertext []byte) []byte {
	var para uintptr
	id := uintptr(tpc.encryptType)

	cipherDecryptInit(id, uintptr(unsafe.Pointer(&tpc.key[0])), uintptr(len(tpc.key)), para)

	out := make([]byte, len(ciphertext))
	written := cipherDecrypt(
		id, para,
		uintptr(unsafe.Pointer(&ciphertext[0])), uintptr(len(ciphertext)),
		uintptr(unsafe.Pointer(&out[0])), uintptr(len(out)),
	)
	out = out[:written]

	cipherClean(id, para)
	return out
}
// addBufSize returns a freshly allocated slice of length newCap whose leading
// bytes are copied from buf. If newCap is smaller than len(buf), only the
// first newCap bytes are kept; any remaining bytes of the result are zero.
func addBufSize(buf []byte, newCap int) []byte {
	var (
		grown = make([]byte, newCap)
		keep  = len(buf)
	)
	if keep > newCap {
		keep = newCap
	}
	copy(grown[:keep], buf[:keep])
	return grown
}
func uintptr2bytes(p uintptr) []byte {
buf := make([]byte, 64)
i := 0
for b := (*byte)(unsafe.Pointer(p)); *b != 0; i++ {
if i > cap(buf) {
buf = addBufSize(buf, i*2)
}
buf[i] = *b
// byte占1字节
p++
b = (*byte)(unsafe.Pointer(p))
}
return buf[:i]
}