diff --git a/block/miniblock.go b/block/miniblock.go
index dc61171b..9f4279a6 100644
--- a/block/miniblock.go
+++ b/block/miniblock.go
@@ -199,12 +199,20 @@ func (mbl *MiniBlock) Deserialize(buf []byte) (err error) {
 	copy(mbl.Nonce[:], buf[15+16+32:])
 	mbl.Height = int64(binary.BigEndian.Uint64(mbl.Check[:]))
 
-	if mbl.GetMiniID() == mbl.Past[0] {
-		return fmt.Errorf("Self Collision")
+	return
+}
+
+// checks for basic sanity
+func (mbl *MiniBlock) IsSafe() bool {
+	id := mbl.GetMiniID()
+	if id == mbl.Past[0] {
+		//return fmt.Errorf("Self Collision")
+		return false
 	}
-	if mbl.PastCount == 2 && mbl.GetMiniID() == mbl.Past[1] {
-		return fmt.Errorf("Self Collision")
+	if mbl.PastCount == 2 && id == mbl.Past[1] {
+		//return fmt.Errorf("Self Collision")
+		return false
 	}
 
-	return
+	return true
 }
diff --git a/block/miniblockdag.go b/block/miniblockdag.go
index aa8c368a..e4b58dc0 100644
--- a/block/miniblockdag.go
+++ b/block/miniblockdag.go
@@ -119,7 +119,7 @@ func (c *MiniBlocksCollection) Get(id uint32) (mbl MiniBlock) {
 	var ok bool
 
 	if mbl, ok = c.Collection[id]; !ok {
-		panic("past should be present")
+		panic("id requested should be present")
 	}
 	return mbl
 }
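The self-collision test moves out of `Deserialize` into a dedicated `IsSafe` predicate, so decoding and validation are now separate steps. A minimal sketch of how a caller is expected to combine the two; the handler name and error text are illustrative, not part of the patch:

```go
// handleIncomingMiniBlock is a hypothetical consumer of the split API.
func handleIncomingMiniBlock(buf []byte) error {
	var mbl block.MiniBlock
	if err := mbl.Deserialize(buf); err != nil { // structural decoding only
		return err
	}
	if !mbl.IsSafe() { // the self-collision check that used to live in Deserialize
		return fmt.Errorf("miniblock %08x references itself", mbl.GetMiniID())
	}
	return nil
}
```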
logger.Info("All caching except mining jobs will be disabled") + } + if params["--simulator"] == true { chain.simulator = true // enable simulator mode, this will set hard coded difficulty to 1 } @@ -197,10 +214,56 @@ func Blockchain_Start(params map[string]interface{}) (*Blockchain, error) { } } - go clean_up_valid_cache() // clean up valid cache - atomic.AddUint32(&globals.Subsystem_Active, 1) // increment subsystem + globals.Cron.AddFunc("@every 360s", clean_up_valid_cache) // cleanup valid tx cache + globals.Cron.AddFunc("@every 60s", func() { // mempool house keeping + + stable_height := int64(0) + if r := recover(); r != nil { + logger.Error(nil, "Mempool House Keeping triggered panic", "r", r, "height", stable_height) + } + + stable_height = chain.Get_Stable_Height() + + // give mempool an oppurtunity to clean up tx, but only if they are not mined + chain.Mempool.HouseKeeping(uint64(stable_height)) + + top_block_topo_index := chain.Load_TOPO_HEIGHT() + + if top_block_topo_index < 10 { + return + } + + top_block_topo_index -= 10 + + blid, err := chain.Load_Block_Topological_order_at_index(top_block_topo_index) + if err != nil { + panic(err) + } + + record_version, err := chain.ReadBlockSnapshotVersion(blid) + if err != nil { + panic(err) + } + + // give regpool a chance to register + if ss, err := chain.Store.Balance_store.LoadSnapshot(record_version); err == nil { + if balance_tree, err := ss.GetTree(config.BALANCE_TREE); err == nil { + chain.Regpool.HouseKeeping(uint64(stable_height), func(tx *transaction.Transaction) bool { + if tx.TransactionType != transaction.REGISTRATION { // tx not registration so delete + return true + } + if _, err := balance_tree.Get(tx.MinerAddress[:]); err != nil { // address already registered + return true + } + return false // account not already registered, so give another chance + }) + + } + } + }) + return &chain, nil } @@ -209,6 +272,45 @@ func (chain *Blockchain) IntegratorAddress() rpc.Address { return chain.integrator_address } +// this function is called to read blockchain state from DB +// It is callable at any point in time + +func (chain *Blockchain) Initialise_Chain_From_DB() { + chain.Lock() + defer chain.Unlock() + + chain.Pruned = chain.LocatePruneTopo() + + // find the tips from the chain , first by reaching top height + // then downgrading to top-10 height + // then reworking the chain to get the tip + best_height := chain.Load_TOP_HEIGHT() + chain.Height = best_height + + chain.Tips = map[crypto.Hash]crypto.Hash{} // reset the map + // reload top tip from disk + top := chain.Get_Top_ID() + + chain.Tips[top] = top // we only can load a single tip from db + + logger.V(1).Info("Reloaded Chain from disk", "Tips", chain.Tips, "Height", chain.Height) +} + +// before shutdown , make sure p2p is confirmed stopped +func (chain *Blockchain) Shutdown() { + + chain.Lock() // take the lock as chain is no longer in unsafe mode + close(chain.Exit_Event) // send signal to everyone we are shutting down + + chain.Mempool.Shutdown() // shutdown mempool first + chain.Regpool.Shutdown() // shutdown regpool first + + logger.Info("Stopping Blockchain") + //chain.Store.Shutdown() + atomic.AddUint32(&globals.Subsystem_Active, ^uint32(0)) // this decrement 1 fom subsystem + logger.Info("Stopped Blockchain") +} + // this is the only entrypoint for new / old blocks even for genesis block // this will add the entire block atomically to the chain // this is the only function which can add blocks to the chain @@ -254,7 +356,7 @@ func (chain *Blockchain) 
Add_Complete_Block(cbl *block.Complete_Block) (err erro } if result == true { // block was successfully added, commit it atomically - logger.V(2).Info("Block successfully accepted by chain", "blid", block_hash.String()) + logger.V(2).Info("Block successfully accepted by chain", "blid", block_hash.String(), "err", err) // gracefully try to instrument func() { @@ -542,7 +644,11 @@ func (chain *Blockchain) Add_Complete_Block(cbl *block.Complete_Block) (err erro var history_array []crypto.Hash for i := range bl.Tips { - history_array = append(history_array, chain.get_ordered_past(bl.Tips[i], 26)...) + h := int64(bl.Height) - 25 + if h < 0 { + h = 0 + } + history_array = append(history_array, chain.get_ordered_past(bl.Tips[i], h)...) } for _, h := range history_array { history[h] = true @@ -683,7 +789,7 @@ func (chain *Blockchain) Add_Complete_Block(cbl *block.Complete_Block) (err erro } - { + if height_changed { var full_order []crypto.Hash var base_topo_index int64 // new topo id will start from here @@ -699,15 +805,18 @@ func (chain *Blockchain) Add_Complete_Block(cbl *block.Complete_Block) (err erro // we will directly use graviton to mov in to history logger.V(3).Info("Full order data", "full_order", full_order, "base_topo_index", base_topo_index) + if base_topo_index < 0 { + logger.Error(nil, "negative base topo, not possible, probably disk corruption or core issue") + os.Exit(0) + } + topos_written := false for i := int64(0); i < int64(len(full_order)); i++ { logger.V(3).Info("will execute order ", "i", i, "blid", full_order[i].String()) current_topo_block := i + base_topo_index - previous_topo_block := current_topo_block - 1 - - _ = previous_topo_block + //previous_topo_block := current_topo_block - 1 - if current_topo_block == chain.Load_Block_Topological_order(full_order[i]) { // skip if same order + if !topos_written && current_topo_block == chain.Load_Block_Topological_order(full_order[i]) { // skip if same order continue } @@ -738,8 +847,6 @@ func (chain *Blockchain) Add_Complete_Block(cbl *block.Complete_Block) (err erro } var balance_tree, sc_meta *graviton.Tree - _ = sc_meta - var ss *graviton.Snapshot if bl_current.Height == 0 { // if it's genesis block if ss, err = chain.Store.Balance_store.LoadSnapshot(0); err != nil { @@ -780,7 +887,10 @@ func (chain *Blockchain) Add_Complete_Block(cbl *block.Complete_Block) (err erro //chain.Store.Topo_store.Write(i+base_topo_index, full_order[i],0, int64(bl_current.Height)) // write entry so as sideblock could work var data_trees []*graviton.Tree - if !chain.isblock_SideBlock_internal(full_order[i], current_topo_block, int64(bl_current.Height)) { + if chain.isblock_SideBlock_internal(full_order[i], current_topo_block, int64(bl_current.Height)) { + logger.V(3).Info("this block is a side block", "height", chain.Load_Block_Height(full_order[i]), "blid", full_order[i]) + } else { + logger.V(3).Info("this block is a full block", "height", chain.Load_Block_Height(full_order[i]), "blid", full_order[i]) sc_change_cache := map[crypto.Hash]*graviton.Tree{} // cache entire changes for entire block @@ -847,9 +957,6 @@ func (chain *Blockchain) Add_Complete_Block(cbl *block.Complete_Block) (err erro } chain.process_miner_transaction(bl_current, bl_current.Height == 0, balance_tree, fees_collected, bl_current.Height) - } else { - block_logger.V(1).Info("this block is a side block", "height", chain.Load_Block_Height(full_order[i]), "blid", full_order[i]) - } // we are here, means everything is okay, lets commit the update balance tree @@ -860,16 +967,15 
@@ func (chain *Blockchain) Add_Complete_Block(cbl *block.Complete_Block) (err erro panic(err) } - chain.StoreBlock(bl, commit_version) - if height_changed { - chain.Store.Topo_store.Write(current_topo_block, full_order[i], commit_version, chain.Load_Block_Height(full_order[i])) - if logger.V(3).Enabled() { - merkle_root, err := chain.Load_Merkle_Hash(commit_version) - if err != nil { - panic(err) - } - logger.V(3).Info("height changed storing topo", "i", i, "blid", full_order[i].String(), "topoheight", current_topo_block, "commit_version", commit_version, "committed_merkle", merkle_root) + chain.StoreBlock(bl_current, commit_version) + topos_written = true + chain.Store.Topo_store.Write(current_topo_block, full_order[i], commit_version, chain.Load_Block_Height(full_order[i])) + if logger.V(3).Enabled() { + merkle_root, err := chain.Load_Merkle_Hash(commit_version) + if err != nil { + panic(err) } + logger.V(3).Info("storing topo", "i", i, "blid", full_order[i].String(), "topoheight", current_topo_block, "commit_version", commit_version, "committed_merkle", merkle_root) } } @@ -921,111 +1027,18 @@ func (chain *Blockchain) Add_Complete_Block(cbl *block.Complete_Block) (err erro block_logger.Info(fmt.Sprintf("Chain Height %d", chain.Height)) } + purge_count := chain.MiniBlocks.PurgeHeight(chain.Get_Stable_Height()) // purge all miniblocks upto this height + logger.V(2).Info("Purged miniblock", "count", purge_count) + result = true // TODO fix hard fork // maintain hard fork votes to keep them SANE //chain.Recount_Votes() // does not return anything - // enable mempool book keeping - - func() { - if r := recover(); r != nil { - logger.Error(nil, "Mempool House Keeping triggered panic", "r", r, "height", block_height) - } - - purge_count := chain.MiniBlocks.PurgeHeight(chain.Get_Stable_Height()) // purge all miniblocks upto this height - logger.V(2).Info("Purged miniblock", "count", purge_count) - - // discard the transactions from mempool if they are present there - chain.Mempool.Monitor() - - for i := 0; i < len(cbl.Txs); i++ { - txid := cbl.Txs[i].GetHash() - - switch cbl.Txs[i].TransactionType { - - case transaction.REGISTRATION: - if chain.Regpool.Regpool_TX_Exist(txid) { - logger.V(3).Info("Deleting TX from regpool", "txid", txid) - chain.Regpool.Regpool_Delete_TX(txid) - continue - } - - case transaction.NORMAL, transaction.BURN_TX, transaction.SC_TX: - if chain.Mempool.Mempool_TX_Exist(txid) { - logger.V(3).Info("Deleting TX from mempool", "txid", txid) - chain.Mempool.Mempool_Delete_TX(txid) - continue - } - - } - - } - - // give mempool an oppurtunity to clean up tx, but only if they are not mined - chain.Mempool.HouseKeeping(uint64(block_height)) - - // give regpool a chance to register - if ss, err := chain.Store.Balance_store.LoadSnapshot(0); err == nil { - if balance_tree, err := ss.GetTree(config.BALANCE_TREE); err == nil { - - chain.Regpool.HouseKeeping(uint64(block_height), func(tx *transaction.Transaction) bool { - if tx.TransactionType != transaction.REGISTRATION { // tx not registration so delete - return true - } - if _, err := balance_tree.Get(tx.MinerAddress[:]); err != nil { // address already registered - return true - } - return false // account not already registered, so give another chance - }) - - } - } - - }() - return // run any handlers necesary to atomically } -// this function is called to read blockchain state from DB -// It is callable at any point in time - -func (chain *Blockchain) Initialise_Chain_From_DB() { - chain.Lock() - defer chain.Unlock() - - 
chain.Pruned = chain.LocatePruneTopo()
-
-	// find the tips from the chain , first by reaching top height
-	// then downgrading to top-10 height
-	// then reworking the chain to get the tip
-	best_height := chain.Load_TOP_HEIGHT()
-	chain.Height = best_height
-
-	chain.Tips = map[crypto.Hash]crypto.Hash{} // reset the map
-	// reload top tip from disk
-	top := chain.Get_Top_ID()
-
-	chain.Tips[top] = top // we only can load a single tip from db
-
-	logger.V(1).Info("Reloaded Chain from disk", "Tips", chain.Tips, "Height", chain.Height)
-}
-
-// before shutdown , make sure p2p is confirmed stopped
-func (chain *Blockchain) Shutdown() {
-
-	chain.Lock()            // take the lock as chain is no longer in unsafe mode
-	close(chain.Exit_Event) // send signal to everyone we are shutting down
-
-	chain.Mempool.Shutdown() // shutdown mempool first
-	chain.Regpool.Shutdown() // shutdown regpool first
-
-	logger.Info("Stopping Blockchain")
-	//chain.Store.Shutdown()
-	atomic.AddUint32(&globals.Subsystem_Active, ^uint32(0)) // this decrement 1 fom subsystem
-}
-
 // get top unstable height
 // this is obtained by getting the highest topo block and getting its height
 func (chain *Blockchain) Get_Height() int64 {
@@ -1433,73 +1446,57 @@ func (chain *Blockchain) IsBlockSyncBlockHeightSpecific(blid crypto.Hash, chain_
 // converts a DAG's partial order into a full order, this function is recursive
 // dag can be processed only one height at a time
 // blocks are ordered recursively, till we find a find a block which is already in the chain
+// this could be done via binary search also, but this is also easy
 func (chain *Blockchain) Generate_Full_Order_New(current_tip crypto.Hash, new_tip crypto.Hash) (order []crypto.Hash, topo int64) {
-	/*if !(chain.Load_Height_for_BL_ID(new_tip) == chain.Load_Height_for_BL_ID(current_tip)+1 ||
-		chain.Load_Height_for_BL_ID(new_tip) == chain.Load_Height_for_BL_ID(current_tip)) {
-		panic("dag can only grow one height at a time")
-	}*/
-
-	depth := 20
-	for ; ; depth += 20 {
-		current_history := chain.get_ordered_past(current_tip, depth)
-		new_history := chain.get_ordered_past(new_tip, depth)
-
-		if len(current_history) < 5 { // we assume chain will not fork before 4 blocks
-			var current_history_rev []crypto.Hash
-			var new_history_rev []crypto.Hash
+	start := time.Now()
+	defer func() { logger.V(2).Info("generating full order", "took", time.Since(start)) }() // measure at return time; defer arguments would be evaluated immediately
 
-			for i := range current_history {
-				current_history_rev = append(current_history_rev, current_history[len(current_history)-i-1])
-			}
-			for i := range new_history {
-				new_history_rev = append(new_history_rev, new_history[len(new_history)-i-1])
-			}
+	matchtill := chain.Load_Height_for_BL_ID(new_tip)
+	step_size := int64(10)
 
-			for j := range new_history_rev {
-				found := false
-				for i := range current_history_rev {
-					if current_history_rev[i] == new_history_rev[j] {
-						found = true
-						break
-					}
-				}
+	for {
+		matchtill -= step_size
+		if matchtill < 0 {
+			matchtill = 0
+		}
+		current_history := chain.get_ordered_past(current_tip, matchtill)
+		new_history := chain.get_ordered_past(new_tip, matchtill)
 
-				if !found { // we have a contention point
-					topo = chain.Load_Block_Topological_order(new_history_rev[j-1])
-					order = append(order, new_history_rev[j-1:]...) // order is already stored and store
-					return
-				}
+		if matchtill == 0 {
+			if current_history[0] != new_history[0] {
+				panic("genesis not matching")
 			}
-			panic("not possible")
+			topo = 0
+			order = append(order, new_history...)
+			return
 		}
 
-		for i := 0; i < len(current_history)-4; i++ {
-			for j := 0; j < len(new_history)-4; j++ {
-				if current_history[i+0] == new_history[j+0] &&
-					current_history[i+1] == new_history[j+1] &&
-					current_history[i+2] == new_history[j+2] &&
-					current_history[i+3] == new_history[j+3] {
+		if current_history[0] != new_history[0] { // base are not matching, step back further
+			continue
+		}
 
-					topo = chain.Load_Block_Topological_order(new_history[j])
-					for k := j; k >= 0; k-- {
-						order = append(order, new_history[k]) // reverse order and store
-					}
-					return
+		if current_history[0] != new_history[0] ||
+			current_history[1] != new_history[1] ||
+			current_history[2] != new_history[2] ||
+			current_history[3] != new_history[3] {
 
-				}
-			}
+			continue // base are not matching, step back further
 		}
+
+		order = append(order, new_history[:]...)
+		topo = chain.Load_Block_Topological_order(order[0])
+		return
 	}
 
 	return
 }
 
 // we will collect atleast 50 blocks or till genesis
-func (chain *Blockchain) get_ordered_past(tip crypto.Hash, count int) (order []crypto.Hash) {
+func (chain *Blockchain) get_ordered_past(tip crypto.Hash, tillheight int64) (order []crypto.Hash) {
 	order = append(order, tip)
 	current := tip
-	for len(order) < count {
+	for chain.Load_Height_for_BL_ID(current) > tillheight {
 		past := chain.Get_Block_Past(current)
 
 		switch len(past) {
@@ -1520,5 +1517,9 @@ func (chain *Blockchain) get_ordered_past(tip crypto.Hash, count int) (order []c
 			panic("data corruption")
 		}
 	}
+
+	for i, j := 0, len(order)-1; i < j; i, j = i+1, j-1 {
+		order[i], order[j] = order[j], order[i]
+	}
 	return
 }
diff --git a/blockchain/difficulty.go b/blockchain/difficulty.go
index 26d18093..76cdc1ff 100644
--- a/blockchain/difficulty.go
+++ b/blockchain/difficulty.go
@@ -212,7 +212,9 @@ func (chain *Blockchain) Get_Difficulty_At_Tips(tips []crypto.Hash) *big.Int {
 		biggest_difficulty.Set(MinimumDifficulty)
 	}
 
-	chain.cache_Get_Difficulty_At_Tips.Add(tips_string, string(biggest_difficulty.Bytes())) // set in cache
+	if !chain.cache_disabled {
+		chain.cache_Get_Difficulty_At_Tips.Add(tips_string, string(biggest_difficulty.Bytes())) // set in cache
+	}
 	return biggest_difficulty
 }
@@ -235,7 +237,9 @@ func (chain *Blockchain) VerifyMiniblockPoW(bl *block.Block, mbl block.MiniBlock
 	}*/
 
 	if CheckPowHashBig(PoW, block_difficulty) == true {
-		chain.cache_IsMiniblockPowValid.Add(fmt.Sprintf("%s", cachekey), true) // set in cache
+		if !chain.cache_disabled {
+			chain.cache_IsMiniblockPowValid.Add(fmt.Sprintf("%s", cachekey), true) // set in cache
+		}
 		return true
 	}
 	return false
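Both hunks above apply the same guard, and the same shape recurs in miner_block.go and transaction_verify.go below. A minimal sketch of the read-through pattern with the new `cache_disabled` flag; the helper and its key/compute parameters are illustrative, only the flag and the LRU calls come from the patch:

```go
// cachedLookup is a hypothetical helper showing the gate: reads always
// consult the cache, but with DISABLE_CACHE set nothing new is stored.
func (chain *Blockchain) cachedLookup(c *lru.Cache, key string, compute func() string) string {
	if v, ok := c.Get(key); ok {
		return v.(string)
	}
	v := compute()
	if !chain.cache_disabled { // same guard as Get_Difficulty_At_Tips above
		c.Add(key, v)
	}
	return v
}
```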
diff --git a/blockchain/mempool/mempool.go b/blockchain/mempool/mempool.go
index 1855feb1..22468d15 100644
--- a/blockchain/mempool/mempool.go
+++ b/blockchain/mempool/mempool.go
@@ -21,8 +21,6 @@ import "sync"
 import "sort"
 import "time"
 import "sync/atomic"
-import "encoding/hex"
-import "encoding/json"
 
 import "github.com/go-logr/logr"
@@ -55,9 +53,6 @@ type Mempool struct {
 	modified bool   // used to monitor whethel mem pool contents have changed,
 	height   uint64 // track blockchain height
 
-	P2P_TX_Relayer p2p_TX_Relayer // actual pointer, setup by the dero daemon during runtime
-
-	relayer chan crypto.Hash // used for immediate relay
 	// global variable , but don't see it utilisation here except fot tx verification
 	//chain *Blockchain
 	Exit_Mutex chan bool
@@ -70,65 +65,12 @@ type mempool_object struct {
 	Tx     *transaction.Transaction
 	Added  uint64 // time in epoch format
 	Height uint64 // at which height the tx unlocks in the mempool
-	Relayed    int   // relayed count
-	RelayedAt  int64 // when was tx last relayed
 	Size       uint64 // size in bytes of the TX
 	FEEperBYTE uint64 // fee per byte
 }
 
 var loggerpool logr.Logger
 
-// marshal object as json
-func (obj *mempool_object) MarshalJSON() ([]byte, error) {
-	return json.Marshal(&struct {
-		Tx        string `json:"tx"` // hex encoding
-		Added     uint64 `json:"added"`
-		Height    uint64 `json:"height"`
-		Relayed   int    `json:"relayed"`
-		RelayedAt int64  `json:"relayedat"`
-	}{
-		Tx:        hex.EncodeToString(obj.Tx.Serialize()),
-		Added:     obj.Added,
-		Height:    obj.Height,
-		Relayed:   obj.Relayed,
-		RelayedAt: obj.RelayedAt,
-	})
-}
-
-// unmarshal object from json encoding
-func (obj *mempool_object) UnmarshalJSON(data []byte) error {
-	aux := &struct {
-		Tx        string `json:"tx"`
-		Added     uint64 `json:"added"`
-		Height    uint64 `json:"height"`
-		Relayed   int    `json:"relayed"`
-		RelayedAt int64  `json:"relayedat"`
-	}{}
-
-	if err := json.Unmarshal(data, &aux); err != nil {
-		return err
-	}
-
-	obj.Added = aux.Added
-	obj.Height = aux.Height
-	obj.Relayed = aux.Relayed
-	obj.RelayedAt = aux.RelayedAt
-
-	tx_bytes, err := hex.DecodeString(aux.Tx)
-	if err != nil {
-		return err
-	}
-	obj.Size = uint64(len(tx_bytes))
-
-	obj.Tx = &transaction.Transaction{}
-	err = obj.Tx.Deserialize(tx_bytes)
-
-	if err == nil {
-		obj.FEEperBYTE = obj.Tx.Fees() / obj.Size
-	}
-	return err
-}
-
 func Init_Mempool(params map[string]interface{}) (*Mempool, error) {
 	var mempool Mempool
 	//mempool.chain = params["chain"].(*Blockchain)
@@ -137,7 +79,6 @@ func Init_Mempool(params map[string]interface{}) (*Mempool, error) {
 	loggerpool.Info("Mempool started")
 	atomic.AddUint32(&globals.Subsystem_Active, 1) // increment subsystem
 
-	mempool.relayer = make(chan crypto.Hash, 1024*10)
 	mempool.Exit_Mutex = make(chan bool)
 
 	metrics.Set.GetOrCreateGauge("mempool_count", func() float64 {
@@ -152,20 +93,6 @@ func Init_Mempool(params map[string]interface{}) (*Mempool, error) {
 	return &mempool, nil
 }
 
-// this is created per incoming block and then discarded
-// This does not require shutting down and will be garbage collected automatically
-/*
-func Init_Block_Mempool(params map[string]interface{}) (*Mempool, error) {
-	var mempool Mempool
-
-	// initialize maps
-	//mempool.txs = map[crypto.Hash]*mempool_object{}
-	//mempool.nonces = map[crypto.Hash]bool{}
-
-	return &mempool, nil
-}
-*/
-
 func (pool *Mempool) HouseKeeping(height uint64) {
 	pool.height = height
@@ -257,8 +184,6 @@ func (pool *Mempool) Mempool_Add_TX(tx *transaction.Transaction, Height uint64)
 		object.FEEperBYTE = tx.Fees() / object.Size
 
 		pool.txs.Store(tx_hash, &object)
-
-		pool.relayer <- tx_hash
 		pool.modified = true // pool has been modified
 
 		//pool.sort_list() // sort and update pool list
@@ -402,12 +327,12 @@ func (pool *Mempool) Mempool_Print() {
 	})
 
 	loggerpool.Info(fmt.Sprintf("Total TX in mempool = %d\n", len(klist)))
-	loggerpool.Info(fmt.Sprintf("%20s %14s %7s %7s %6s %32s\n", "Added", "Last Relayed", "Relayed", "Size", "Height", "TXID"))
+	loggerpool.Info(fmt.Sprintf("%20s %7s %6s %32s\n", "Added", "Size", "Height", "TXID"))
 
 	for i := range klist {
 		k := klist[i]
 		v := vlist[i]
-		loggerpool.Info(fmt.Sprintf("%20s %14s %7d %7d %6d %32s\n", time.Unix(int64(v.Added), 0).UTC().Format(time.RFC3339), time.Duration(v.RelayedAt)*time.Second, v.Relayed,
+		loggerpool.Info(fmt.Sprintf("%20s %7d %6d %32s\n", time.Unix(int64(v.Added), 0).UTC().Format(time.RFC3339),
 			len(v.Tx.Serialize()), v.Height, k))
 	}
 }
@@ -461,5 +386,3 @@ func (pool *Mempool) sort_list() ([]crypto.Hash, []TX_Sorting_struct) {
 
 	return sorted_list, data
 }
-
-type p2p_TX_Relayer func(*transaction.Transaction, uint64) int // function type, exported in p2p but cannot use due to cyclic dependency
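With the relay bookkeeping gone, `FEEperBYTE` (set in `Mempool_Add_TX` above) is what remains to rank transactions. A sketch of the ordering it implies; `txInfo` is a stand-in, since `TX_Sorting_struct`'s layout is not shown in this patch:

```go
// txInfo mirrors just the fields needed to order candidates by fee density.
type txInfo struct {
	Hash       crypto.Hash
	FEEperBYTE uint64 // tx.Fees() / serialized size, as computed in Mempool_Add_TX
}

func sortByFeeDensity(list []txInfo) {
	sort.Slice(list, func(i, j int) bool {
		return list[i].FEEperBYTE > list[j].FEEperBYTE // highest-paying tx first
	})
}
```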
diff --git a/blockchain/miner_block.go b/blockchain/miner_block.go
index 7866580a..fe37a6b3 100644
--- a/blockchain/miner_block.go
+++ b/blockchain/miner_block.go
@@ -523,7 +523,17 @@ func (chain *Blockchain) Accept_new_block(tstamp uint64, miniblock_blob []byte)
 	// notify peers, we have a miniblock and return to miner
 	if !chain.simulator { // if not in simulator mode, relay miniblock to the chain
-		go chain.P2P_MiniBlock_Relayer(mbl, 0)
+		var mbls []block.MiniBlock
+
+		if !mbl.Genesis {
+			for i := uint8(0); i < mbl.PastCount; i++ {
+				mbls = append(mbls, chain.MiniBlocks.Get(mbl.Past[i]))
+			}
+
+		}
+		mbls = append(mbls, mbl)
+		go chain.P2P_MiniBlock_Relayer(mbls, 0)
+
 	}
 
 	// if a duplicate block is being sent, reject the block
@@ -665,7 +675,9 @@ hard_way:
 		if err != nil || bits >= 120 {
 			return
 		}
-		chain.cache_IsAddressHashValid.Add(fmt.Sprintf("%s", hash), true) // set in cache
+		if !chain.cache_disabled {
+			chain.cache_IsAddressHashValid.Add(fmt.Sprintf("%s", hash), true) // set in cache
+		}
 	}
 
 	return true
diff --git a/blockchain/miniblocks_consensus.go b/blockchain/miniblocks_consensus.go
index 229f254d..64b73e09 100644
--- a/blockchain/miniblocks_consensus.go
+++ b/blockchain/miniblocks_consensus.go
@@ -80,6 +80,9 @@ func (chain *Blockchain) Verify_MiniBlocks(bl block.Block) (err error) {
 
 	// check whether the genesis blocks are all equal
 	for _, mbl := range bl.MiniBlocks {
+		if !mbl.IsSafe() {
+			return fmt.Errorf("MiniBlock is unsafe")
+		}
 		if mbl.Genesis { // make sure all genesis blocks point to all the actual tips
 			if bl.Height != binary.BigEndian.Uint64(mbl.Check[:]) {
@@ -194,7 +197,9 @@ func (chain *Blockchain) Check_Dynamism(mbls []block.MiniBlock) (err error) {
 
 // insert a miniblock to chain and if successfull inserted, notify everyone in need
 func (chain *Blockchain) InsertMiniBlock(mbl block.MiniBlock) (err error, result bool) {
-
+	if !mbl.IsSafe() {
+		return fmt.Errorf("miniblock is unsafe"), false
+	}
 	var miner_hash crypto.Hash
 	copy(miner_hash[:], mbl.KeyHash[:])
 	if !chain.IsAddressHashValid(true, miner_hash) {
diff --git a/blockchain/regpool/regpool.go b/blockchain/regpool/regpool.go
index 165ac9a6..1097546a 100644
--- a/blockchain/regpool/regpool.go
+++ b/blockchain/regpool/regpool.go
@@ -55,8 +55,6 @@ type Regpool struct {
 	modified bool   // used to monitor whethel mem pool contents have changed,
 	height   uint64 // track blockchain height
 
-	relayer chan crypto.Hash // used for immediate relay
-
 	// global variable , but don't see it utilisation here except fot tx verification
 	//chain *Blockchain
 	Exit_Mutex chan bool
@@ -136,7 +134,6 @@ func Init_Regpool(params map[string]interface{}) (*Regpool, error) {
 	loggerpool.Info("Regpool started")
 	atomic.AddUint32(&globals.Subsystem_Active, 1) // increment subsystem
 
-	regpool.relayer = make(chan crypto.Hash, 1024*10)
 	regpool.Exit_Mutex = make(chan bool)
 
 	metrics.Set.GetOrCreateGauge("regpool_count", func() float64 {
@@ -259,7 +256,6 @@ func (pool *Regpool) Regpool_Add_TX(tx *transaction.Transaction, Height uint64)
 		object.Size = uint64(len(tx.Serialize()))
 
 		pool.txs.Store(tx_hash, &object)
-		pool.relayer <- tx_hash
 		pool.modified = true // pool has been modified
 
 		//pool.sort_list() // sort and update pool list
diff --git a/blockchain/store.go b/blockchain/store.go
index c6a280db..add15c0e 100644
--- a/blockchain/store.go
+++ b/blockchain/store.go
@@ -17,7 +17,6 @@ package blockchain
 
 import "fmt"
-import "sync"
import "math/big" import "crypto/rand" import "path/filepath" @@ -29,8 +28,6 @@ import "github.com/deroproject/derohe/cryptography/crypto" import "github.com/deroproject/graviton" -import "github.com/hashicorp/golang-lru" - // though these can be done within a single DB, these are separated for completely clarity purposes type storage struct { Balance_store *graviton.Store // stores most critical data, only history can be purged, its merkle tree is stored in the block @@ -98,7 +95,7 @@ func (chain *Blockchain) StoreBlock(bl *block.Block, snapshot_version uint64) { chain.Store.Block_tx_store.DeleteBlock(hash) // what should we do on error - err := chain.Store.Block_tx_store.WriteBlock(hash, serialized_bytes, difficulty_of_current_block, snapshot_version) + err := chain.Store.Block_tx_store.WriteBlock(hash, serialized_bytes, difficulty_of_current_block, snapshot_version, bl.Height) if err != nil { panic(fmt.Sprintf("error while writing block")) } @@ -169,36 +166,27 @@ func (chain *Blockchain) Load_Block_Timestamp(h crypto.Hash) uint64 { } func (chain *Blockchain) Load_Block_Height(h crypto.Hash) (height int64) { - defer func() { if r := recover(); r != nil { height = -1 } }() - bl, err := chain.Load_BL_FROM_ID(h) - if err != nil { - panic(err) + if heighti, err := chain.ReadBlockHeight(h); err != nil { + return -1 + } else { + return int64(heighti) } - height = int64(bl.Height) - - return } func (chain *Blockchain) Load_Height_for_BL_ID(h crypto.Hash) int64 { return chain.Load_Block_Height(h) } -var past_cache, _ = lru.New(10240) -var past_cache_lock sync.Mutex - // all the immediate past of a block func (chain *Blockchain) Get_Block_Past(hash crypto.Hash) (blocks []crypto.Hash) { //fmt.Printf("loading tips for block %x\n", hash) - past_cache_lock.Lock() - defer past_cache_lock.Unlock() - - if keysi, ok := past_cache.Get(hash); ok { + if keysi, ok := chain.cache_BlockPast.Get(hash); ok { keys := keysi.([]crypto.Hash) blocks = make([]crypto.Hash, len(keys)) for i := range keys { @@ -223,7 +211,7 @@ func (chain *Blockchain) Get_Block_Past(hash crypto.Hash) (blocks []crypto.Hash) } //set in cache - past_cache.Add(hash, cache_copy) + chain.cache_BlockPast.Add(hash, cache_copy) return } diff --git a/blockchain/storefs.go b/blockchain/storefs.go index e410d050..53329466 100644 --- a/blockchain/storefs.go +++ b/blockchain/storefs.go @@ -24,6 +24,7 @@ import "strings" import "io/ioutil" import "math/big" import "path/filepath" +import "github.com/deroproject/derohe/globals" type storefs struct { basedir string @@ -33,6 +34,7 @@ type storefs struct { // hex block id (64 chars).block._ rewards (decimal) _ difficulty _ cumulative difficulty func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) { + defer globals.Recover(0) var dummy [32]byte if h == dummy { return nil, fmt.Errorf("empty block") @@ -40,7 +42,7 @@ func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) { dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2])) - files, err := ioutil.ReadDir(dir) + files, err := os.ReadDir(dir) if err != nil { return nil, err } @@ -50,7 +52,7 @@ func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) { if strings.HasPrefix(file.Name(), filename_start) { //fmt.Printf("Reading block with filename %s\n", file.Name()) file := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]), file.Name()) - return ioutil.ReadFile(file) + return 
diff --git a/blockchain/storefs.go b/blockchain/storefs.go
index e410d050..53329466 100644
--- a/blockchain/storefs.go
+++ b/blockchain/storefs.go
@@ -24,6 +24,7 @@ import "strings"
 import "io/ioutil"
 import "math/big"
 import "path/filepath"
+import "github.com/deroproject/derohe/globals"
 
 type storefs struct {
 	basedir string
@@ -33,6 +34,7 @@ type storefs struct {
 // hex block id (64 chars).block._ rewards (decimal) _ difficulty _ cumulative difficulty
 func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) {
+	defer globals.Recover(0)
 	var dummy [32]byte
 	if h == dummy {
 		return nil, fmt.Errorf("empty block")
@@ -40,7 +42,7 @@ func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) {
 
 	dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]))
 
-	files, err := ioutil.ReadDir(dir)
+	files, err := os.ReadDir(dir)
 	if err != nil {
 		return nil, err
 	}
@@ -50,7 +52,7 @@ func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) {
 		if strings.HasPrefix(file.Name(), filename_start) {
 			//fmt.Printf("Reading block with filename %s\n", file.Name())
 			file := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]), file.Name())
-			return ioutil.ReadFile(file)
+			return os.ReadFile(file)
 		}
 	}
@@ -60,7 +62,7 @@ func (s *storefs) ReadBlock(h [32]byte) ([]byte, error) {
 func (s *storefs) DeleteBlock(h [32]byte) error {
 	dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]))
 
-	files, err := ioutil.ReadDir(dir)
+	files, err := os.ReadDir(dir)
 	if err != nil {
 		return err
 	}
@@ -87,7 +89,7 @@ func (s *storefs) DeleteBlock(h [32]byte) error {
 func (s *storefs) ReadBlockDifficulty(h [32]byte) (*big.Int, error) {
 	dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]))
 
-	files, err := ioutil.ReadDir(dir)
+	files, err := os.ReadDir(dir)
 	if err != nil {
 		return nil, err
 	}
@@ -99,7 +101,7 @@ func (s *storefs) ReadBlockDifficulty(h [32]byte) (*big.Int, error) {
 			diff := new(big.Int)
 
 			parts := strings.Split(file.Name(), "_")
-			if len(parts) != 3 {
+			if len(parts) != 4 {
 				panic("such filename cannot occur")
 			}
@@ -120,7 +122,7 @@ func (chain *Blockchain) ReadBlockSnapshotVersion(h [32]byte) (uint64, error) {
 func (s *storefs) ReadBlockSnapshotVersion(h [32]byte) (uint64, error) {
 	dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]))
 
-	files, err := ioutil.ReadDir(dir)
+	files, err := os.ReadDir(dir)
 	if err != nil {
 		return 0, err
 	}
@@ -128,28 +130,65 @@ func (s *storefs) ReadBlockSnapshotVersion(h [32]byte) (uint64, error) {
 	filename_start := fmt.Sprintf("%x.block", h[:])
 	for _, file := range files {
 		if strings.HasPrefix(file.Name(), filename_start) {
+			var ssversion uint64
+			parts := strings.Split(file.Name(), "_")
+			if len(parts) != 4 {
+				panic("such filename cannot occur")
+			}
+			_, err := fmt.Sscan(parts[2], &ssversion)
+			if err != nil {
+				return 0, err
+			}
+			return ssversion, nil
+		}
+	}
+
+	return 0, os.ErrNotExist
+}
+
+func (chain *Blockchain) ReadBlockHeight(h [32]byte) (uint64, error) {
+	if heighti, ok := chain.cache_BlockHeight.Get(h); ok {
+		height := heighti.(uint64)
+		return height, nil
+	}
 
-			var diff uint64
+	height, err := chain.Store.Block_tx_store.ReadBlockHeight(h)
+	if err == nil {
+		chain.cache_BlockHeight.Add(h, height)
+	}
+	return height, err
+}
+
+func (s *storefs) ReadBlockHeight(h [32]byte) (uint64, error) {
+	dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]))
+	files, err := os.ReadDir(dir)
+	if err != nil {
+		return 0, err
+	}
+
+	filename_start := fmt.Sprintf("%x.block", h[:])
+	for _, file := range files {
+		if strings.HasPrefix(file.Name(), filename_start) {
+			var height uint64
 			parts := strings.Split(file.Name(), "_")
-			if len(parts) != 3 {
+			if len(parts) != 4 {
 				panic("such filename cannot occur")
 			}
-
-			_, err := fmt.Sscan(parts[2], &diff)
+			_, err := fmt.Sscan(parts[3], &height)
 			if err != nil {
 				return 0, err
 			}
-			return diff, nil
+			return height, nil
 		}
 	}
 
 	return 0, os.ErrNotExist
 }
 
-func (s *storefs) WriteBlock(h [32]byte, data []byte, difficulty *big.Int, ss_version uint64) (err error) {
+func (s *storefs) WriteBlock(h [32]byte, data []byte, difficulty *big.Int, ss_version uint64, height uint64) (err error) {
 	dir := filepath.Join(filepath.Join(s.basedir, "bltx_store"), fmt.Sprintf("%02x", h[0]), fmt.Sprintf("%02x", h[1]), fmt.Sprintf("%02x", h[2]))
 
-	file := filepath.Join(dir, fmt.Sprintf("%x.block_%s_%d", h[:], difficulty.String(), ss_version))
+	file := filepath.Join(dir, 
fmt.Sprintf("%x.block_%s_%d_%d", h[:], difficulty.String(), ss_version, height)) if err = os.MkdirAll(dir, 0700); err != nil { return err } diff --git a/blockchain/transaction_verify.go b/blockchain/transaction_verify.go index 8fa9950b..18160562 100644 --- a/blockchain/transaction_verify.go +++ b/blockchain/transaction_verify.go @@ -37,76 +37,31 @@ import "github.com/deroproject/derohe/cryptography/crypto" import "github.com/deroproject/derohe/transaction" import "github.com/deroproject/derohe/cryptography/bn256" -//import "github.com/deroproject/derosuite/emission" - // caches x of transactions validity // it is always atomic -// the cache is not txhash -> validity mapping -// instead it is txhash+expanded ringmembers +// the cache is txhash -> validity mapping // if the entry exist, the tx is valid // it stores special hash and first seen time -// this can only be used on expanded transactions var transaction_valid_cache sync.Map // this go routine continuously scans and cleans up the cache for expired entries func clean_up_valid_cache() { - - for { - time.Sleep(3600 * time.Second) - current_time := time.Now() - - // track propagation upto 10 minutes - transaction_valid_cache.Range(func(k, value interface{}) bool { - first_seen := value.(time.Time) - if current_time.Sub(first_seen).Round(time.Second).Seconds() > 3600 { - transaction_valid_cache.Delete(k) - } - return true - }) - - } + current_time := time.Now() + transaction_valid_cache.Range(func(k, value interface{}) bool { + first_seen := value.(time.Time) + if current_time.Sub(first_seen).Round(time.Second).Seconds() > 3600 { + transaction_valid_cache.Delete(k) + } + return true + }) } -/* Coinbase transactions need to verify registration - * */ +// Coinbase transactions need to verify registration func (chain *Blockchain) Verify_Transaction_Coinbase(cbl *block.Complete_Block, minertx *transaction.Transaction) (err error) { - if !minertx.IsCoinbase() { // transaction is not coinbase, return failed return fmt.Errorf("tx is not coinbase") } - // make sure miner address is registered - - _, topos := chain.Store.Topo_store.binarySearchHeight(int64(cbl.Bl.Height - 1)) - // load all db versions one by one and check whether the root hash matches the one mentioned in the tx - if len(topos) < 1 { - return fmt.Errorf("could not find previous height blocks %d", cbl.Bl.Height-1) - } - - var balance_tree *graviton.Tree - for i := range topos { - - toporecord, err := chain.Store.Topo_store.Read(topos[i]) - if err != nil { - return fmt.Errorf("could not read block at height %d due to error while obtaining toporecord topos %+v processing %d err:%s\n", cbl.Bl.Height-1, topos, i, err) - } - - ss, err := chain.Store.Balance_store.LoadSnapshot(toporecord.State_Version) - if err != nil { - return err - } - - if balance_tree, err = ss.GetTree(config.BALANCE_TREE); err != nil { - return err - } - - if _, err := balance_tree.Get(minertx.MinerAddress[:]); err != nil { - return fmt.Errorf("balance not obtained err %s\n", err) - //return false - } - - } - return nil // success comes last } @@ -229,7 +184,9 @@ func (chain *Blockchain) Verify_Transaction_NonCoinbase_CheckNonce_Tips(hf_versi } } - chain.cache_IsNonceValidTips.Add(tips_string, true) // set in cache + if !chain.cache_disabled { + chain.cache_IsNonceValidTips.Add(tips_string, true) // set in cache + } return nil } @@ -271,7 +228,9 @@ func (chain *Blockchain) verify_Transaction_NonCoinbase_internal(skip_proof bool } if tx.IsRegistrationValid() { - transaction_valid_cache.Store(tx_hash, time.Now()) // 
signature got verified, cache it + if !chain.cache_disabled { + transaction_valid_cache.Store(tx_hash, time.Now()) // signature got verified, cache it + } return nil } return fmt.Errorf("Registration has invalid signature") @@ -482,7 +441,10 @@ func (chain *Blockchain) verify_Transaction_NonCoinbase_internal(skip_proof bool // these transactions are done if tx.TransactionType == transaction.NORMAL || tx.TransactionType == transaction.BURN_TX || tx.TransactionType == transaction.SC_TX { - transaction_valid_cache.Store(tx_hash, time.Now()) // signature got verified, cache it + if !chain.cache_disabled { + transaction_valid_cache.Store(tx_hash, time.Now()) // signature got verified, cache it + } + return nil } diff --git a/cmd/dero-wallet-cli/easymenu_post_open.go b/cmd/dero-wallet-cli/easymenu_post_open.go index a7bbd52d..02c47463 100644 --- a/cmd/dero-wallet-cli/easymenu_post_open.go +++ b/cmd/dero-wallet-cli/easymenu_post_open.go @@ -130,8 +130,8 @@ func handle_easymenu_post_open_command(l *readline.Instance, line string) (proce reg_tx := wallet.GetRegistrationTX() // at this point we must send the registration transaction - fmt.Fprintf(l.Stderr(), "Wallet address : "+color_green+"%s"+color_white+" is going to be registered.Pls wait till the account is registered.\n", wallet.GetAddress()) + fmt.Fprintf(l.Stderr(), "Registration TXID %s\n", reg_tx.GetHash()) err := wallet.SendTransaction(reg_tx) if err != nil { fmt.Fprintf(l.Stderr(), "sending registration tx err %s\n", err) diff --git a/cmd/dero-wallet-cli/easymenu_pre_open.go b/cmd/dero-wallet-cli/easymenu_pre_open.go index 24469433..5fd6cb49 100644 --- a/cmd/dero-wallet-cli/easymenu_pre_open.go +++ b/cmd/dero-wallet-cli/easymenu_pre_open.go @@ -38,9 +38,6 @@ func display_easymenu_pre_open_command(l *readline.Instance) { io.WriteString(w, "\t\033[1m2\033[0m\tCreate New Wallet\n") io.WriteString(w, "\t\033[1m3\033[0m\tRecover Wallet using recovery seed (25 words)\n") io.WriteString(w, "\t\033[1m4\033[0m\tRecover Wallet using recovery key (64 char private spend key hex)\n") - // io.WriteString(w, "\t\033[1m5\033[0m\tCreate Watch-able Wallet (view only) using wallet view key\n") - // io.WriteString(w, "\t\033[1m6\033[0m\tRecover Non-deterministic Wallet key\n") - io.WriteString(w, "\n\t\033[1m9\033[0m\tExit menu and start prompt\n") io.WriteString(w, "\t\033[1m0\033[0m\tExit Wallet\n") } diff --git a/cmd/dero-wallet-cli/main.go b/cmd/dero-wallet-cli/main.go index 85281423..3ff71d52 100644 --- a/cmd/dero-wallet-cli/main.go +++ b/cmd/dero-wallet-cli/main.go @@ -355,10 +355,9 @@ func main() { func update_prompt(l *readline.Instance) { last_wallet_height := uint64(0) - last_daemon_height := uint64(0) + last_daemon_height := int64(0) daemon_online := false last_update_time := int64(0) - for { time.Sleep(30 * time.Millisecond) // give user a smooth running number @@ -385,7 +384,8 @@ func update_prompt(l *readline.Instance) { } if wallet == nil { - l.SetPrompt(fmt.Sprintf("\033[1m\033[32m%s \033[0m"+color_green+"0/%d \033[32m>>>\033[0m ", address_trim, 0)) + l.SetPrompt(fmt.Sprintf("\033[1m\033[32m%s \033[0m"+color_green+"0/%d \033[32m>>>\033[0m ", address_trim, walletapi.Get_Daemon_Height())) + l.Refresh() prompt_mutex.Unlock() continue } @@ -395,7 +395,7 @@ func update_prompt(l *readline.Instance) { _ = daemon_online //fmt.Printf("chekcing if update is required\n") - if last_wallet_height != wallet.Get_Height() || last_daemon_height != wallet.Get_Daemon_Height() || + if last_wallet_height != wallet.Get_Height() || last_daemon_height != 
walletapi.Get_Daemon_Height() || /*daemon_online != wallet.IsDaemonOnlineCached() ||*/ (time.Now().Unix()-last_update_time) >= 1 { // choose color based on urgency color := "\033[32m" // default is green color @@ -403,7 +403,7 @@ func update_prompt(l *readline.Instance) { color = "\033[33m" // make prompt yellow } - dheight := wallet.Get_Daemon_Height() + //dheight := walletapi.Get_Daemon_Height() /*if wallet.IsDaemonOnlineCached() == false { color = "\033[33m" // make prompt yellow @@ -427,10 +427,10 @@ func update_prompt(l *readline.Instance) { testnet_string = "\033[31m TESTNET" } - l.SetPrompt(fmt.Sprintf("\033[1m\033[32m%s \033[0m"+color+"%d/%d %s %s\033[32m>>>\033[0m ", address_trim, wallet.Get_Height(), dheight, balance_string, testnet_string)) + l.SetPrompt(fmt.Sprintf("\033[1m\033[32m%s \033[0m"+color+"%d/%d %s %s\033[32m>>>\033[0m ", address_trim, wallet.Get_Height(), walletapi.Get_Daemon_Height(), balance_string, testnet_string)) l.Refresh() last_wallet_height = wallet.Get_Height() - last_daemon_height = wallet.Get_Daemon_Height() + last_daemon_height = walletapi.Get_Daemon_Height() last_update_time = time.Now().Unix() //daemon_online = wallet.IsDaemonOnlineCached() _ = last_update_time diff --git a/cmd/derod/main.go b/cmd/derod/main.go index 80503f11..9a50f5ac 100644 --- a/cmd/derod/main.go +++ b/cmd/derod/main.go @@ -47,6 +47,7 @@ import "gopkg.in/natefinch/lumberjack.v2" import "github.com/deroproject/derohe/p2p" import "github.com/deroproject/derohe/globals" import "github.com/deroproject/derohe/block" +import "github.com/deroproject/derohe/transaction" import "github.com/deroproject/derohe/config" import "github.com/deroproject/derohe/rpc" import "github.com/deroproject/derohe/blockchain" @@ -199,10 +200,12 @@ func main() { p2p.Broadcast_Block(cbl, peerid) } - chain.P2P_MiniBlock_Relayer = func(mbl block.MiniBlock, peerid uint64) { + chain.P2P_MiniBlock_Relayer = func(mbl []block.MiniBlock, peerid uint64) { p2p.Broadcast_MiniBlock(mbl, peerid) } + globals.Cron.Start() // start cron jobs + // This tiny goroutine continuously updates status as required go func() { last_our_height := int64(0) @@ -226,14 +229,6 @@ func main() { mempool_tx_count := len(chain.Mempool.Mempool_List_TX()) regpool_tx_count := len(chain.Regpool.Regpool_List_TX()) - /*if our_height < 0 { // somehow the data folder got deleted/renamed/corrupted - logger.Error(nil, "Somehow the data directory is not accessible. 
shutting down") - l.Terminal.ExitRawMode() - l.Terminal.Print("\n\n") - os.Exit(-1) - return - }*/ - // only update prompt if needed if last_second != time.Now().Unix() || last_our_height != our_height || last_best_height != best_height || last_peer_count != peer_count || last_topo_height != topo_height || last_mempool_tx_count != mempool_tx_count || last_regpool_tx_count != regpool_tx_count { // choose color based on urgency @@ -605,37 +600,39 @@ restart_loop: } case command == "print_tx": - /* - - if len(line_parts) == 2 && len(line_parts[1]) == 64 { - txid, err := hex.DecodeString(strings.ToLower(line_parts[1])) - - if err != nil { - fmt.Printf("err while decoding txid err %s\n", err) - continue - } - var hash crypto.Hash - copy(hash[:32], []byte(txid)) + if len(line_parts) == 2 && len(line_parts[1]) == 64 { + txid, err := hex.DecodeString(strings.ToLower(line_parts[1])) - tx, err := chain.Load_TX_FROM_ID(nil, hash) - if err == nil { - //s_bytes := tx.Serialize() - //fmt.Printf("tx : %x\n", s_bytes) - json_bytes, err := json.MarshalIndent(tx, "", " ") - _ = err - fmt.Printf("%s\n", string(json_bytes)) + if err != nil { + fmt.Printf("err while decoding txid err %s\n", err) + continue + } + var hash crypto.Hash + copy(hash[:32], []byte(txid)) - //tx.RctSignature.Message = ringct.Key(tx.GetPrefixHash()) - //ringct.Get_pre_mlsag_hash(tx.RctSignature) - //chain.Expand_Transaction_v2(tx) + var tx transaction.Transaction + if tx_bytes, err := chain.Store.Block_tx_store.ReadTX(hash); err != nil { + fmt.Printf("err while reading txid err %s\n", err) + continue + } else if err = tx.Deserialize(tx_bytes); err != nil { + fmt.Printf("err deserializing tx err %s\n", err) + continue + } - } else { - fmt.Printf("Err %s\n", err) - } + if valid_blid, invalid, valid := chain.IS_TX_Valid(hash); valid { + fmt.Printf("TX is valid in block %s\n", valid_blid) + } else if len(invalid) == 0 { + fmt.Printf("TX is mined in a side chain\n") } else { - fmt.Printf("print_tx needs a single transaction id as arugument\n") + fmt.Printf("TX is mined in blocks %+v\n", invalid) + } + if tx.IsRegistration() { + fmt.Printf("Registration TX validity could not be detected\n") } - */ + + } else { + fmt.Printf("print_tx needs a single transaction id as arugument\n") + } case strings.ToLower(line) == "status": inc, out := p2p.Peer_Direction_Count() @@ -782,46 +779,49 @@ restart_loop: logger.Error(fmt.Errorf("POP needs argument n to pop this many blocks from the top"), "") } + case command == "gc": + runtime.GC() + case command == "ban": - /* - if len(line_parts) >= 4 || len(line_parts) == 1 { - fmt.Printf("IP address required to ban\n") - break - } - if len(line_parts) == 3 { // process ban time if provided - // if user provided a time, apply ban for specific time - if s, err := strconv.ParseInt(line_parts[2], 10, 64); err == nil && s >= 0 { - p2p.Ban_Address(line_parts[1], uint64(s)) - break - } else { - fmt.Printf("err parsing ban time (only positive number) %s", err) - break - } - } + if len(line_parts) >= 4 || len(line_parts) == 1 { + fmt.Printf("IP address required to ban\n") + break + } - err := p2p.Ban_Address(line_parts[1], 10*60) // default ban is 10 minutes - if err != nil { - fmt.Printf("err parsing address %s", err) + if len(line_parts) == 3 { // process ban time if provided + // if user provided a time, apply ban for specific time + if s, err := strconv.ParseInt(line_parts[2], 10, 64); err == nil && s >= 0 { + p2p.Ban_Address(line_parts[1], uint64(s)) break - } - */ - case command == "unban": - /* - if 
-				if len(line_parts) >= 3 || len(line_parts) == 1 {
-					fmt.Printf("IP address required to unban\n")
 					break
 				}
+				} else {
+					fmt.Printf("err parsing ban time (only positive number) %s", err)
 					break
 				}
+			}
+
+			err := p2p.Ban_Address(line_parts[1], 10*60) // default ban is 10 minutes
+			if err != nil {
+				fmt.Printf("err parsing address %s", err)
+				break
+			}
+
+		case command == "unban":
+
+			if len(line_parts) >= 3 || len(line_parts) == 1 {
+				fmt.Printf("IP address required to unban\n")
+				break
+			}
+
+			err := p2p.UnBan_Address(line_parts[1])
+			if err != nil {
+				fmt.Printf("err unbanning %s, err = %s", line_parts[1], err)
+			} else {
+				fmt.Printf("unban %s successful", line_parts[1])
+			}
-				err := p2p.UnBan_Address(line_parts[1])
-				if err != nil {
-					fmt.Printf("err unbanning %s, err = %s", line_parts[1], err)
-				} else {
-					fmt.Printf("unbann %s successful", line_parts[1])
-				}
-			*/
 		case command == "bans":
-			//p2p.BanList_Print() // print ban list
+			p2p.BanList_Print() // print ban list
 
 		case line == "sleep":
 			logger.Info("console sleeping for 1 second")
@@ -948,6 +948,7 @@ func usage(w io.Writer) {
 var completer = readline.NewPrefixCompleter(
 	readline.PcItem("help"),
 	readline.PcItem("diff"),
+	readline.PcItem("gc"),
 	readline.PcItem("mempool_flush"),
 	readline.PcItem("mempool_delete_tx"),
 	readline.PcItem("mempool_print"),
diff --git a/cmd/derod/rpc/rpc_dero_getsc.go b/cmd/derod/rpc/rpc_dero_getsc.go
index 965a529a..cea95a41 100644
--- a/cmd/derod/rpc/rpc_dero_getsc.go
+++ b/cmd/derod/rpc/rpc_dero_getsc.go
@@ -106,7 +106,7 @@ func GetSC(ctx context.Context, p rpc.GetSC_Params) (result rpc.GetSC_Result, er
 				_ = k
 				_ = v
-				fmt.Printf("key '%x' value '%x'\n", k, v)
+				//fmt.Printf("key '%x' value '%x'\n", k, v)
 				if len(k) == 32 && len(v) == 8 { // it's SC balance
 					result.Balances[fmt.Sprintf("%x", k)] = binary.BigEndian.Uint64(v)
 				} else if k[len(k)-1] >= 0x3 && k[len(k)-1] < 0x80 && nil == vark.UnmarshalBinary(k) && nil == varv.UnmarshalBinary(v) {
diff --git a/cmd/simulator/simulator.go b/cmd/simulator/simulator.go
index e2d8afcd..de31c31a 100644
--- a/cmd/simulator/simulator.go
+++ b/cmd/simulator/simulator.go
@@ -527,9 +527,13 @@ func mine_block_auto(chain *blockchain.Blockchain, miner_address rpc.Address) {
 	last_block_time := time.Now()
 	for {
 
+		bl, _, _, _, err := chain.Create_new_block_template_mining(miner_address)
+		if err != nil {
+			logger.Error(err, "error while building mining block")
+		}
+
 		if time.Now().Sub(last_block_time) > time.Duration(config.BLOCK_TIME)*time.Second || // every X secs generate a block
-			len(chain.Mempool.Mempool_List_TX_SortedInfo()) >= 1 ||
-			len(chain.Regpool.Regpool_List_TX()) >= 1 { //pools have a tx, try to mine them ASAP
+			len(bl.Tx_hashes) >= 1 { //pools have a tx, try to mine them ASAP
 
 			if err := mine_block_single(chain, miner_address); err != nil {
 				time.Sleep(time.Second)
diff --git a/config/config.go b/config/config.go
index e223ee16..5719ec34 100644
--- a/config/config.go
+++ b/config/config.go
@@ -56,7 +56,7 @@ const MAINNET_BOOTSTRAP_DIFFICULTY = uint64(80000000) // atlantis mainnet botstr
 const MAINNET_MINIMUM_DIFFICULTY = uint64(800000000)  // 80 MH/s
 
 // testnet bootstraps at 1 MH
-const TESTNET_BOOTSTRAP_DIFFICULTY = uint64(50000) // testnet bootstrap at 50KH/s
+const TESTNET_BOOTSTRAP_DIFFICULTY = uint64(10000) // testnet bootstrap at 10KH/s
 const TESTNET_MINIMUM_DIFFICULTY = uint64(10000)   // 10KH/s
 
 // this single parameter controls lots of various parameters
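The lowered bootstrap value now matches `TESTNET_MINIMUM_DIFFICULTY`, whose annotation for the same 10000 is 10 KH/s. For orientation, difficulty relates to hashrate through the usual PoW identity; a hedged sketch, with the block time left as a parameter since its concrete value is not shown in this hunk:

```go
// impliedHashrate approximates the network hashrate (hashes/sec) a
// difficulty corresponds to, via hashrate ≈ difficulty / block_time.
func impliedHashrate(difficulty uint64, blockTimeSecs uint64) uint64 {
	return difficulty / blockTimeSecs
}
```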
diff --git a/config/seed_nodes.go b/config/seed_nodes.go
index cab4d186..92837e9e 100644
--- a/config/seed_nodes.go
+++ b/config/seed_nodes.go
@@ -29,5 +29,5 @@ var Mainnet_seed_nodes = []string{
 
 // some seed node for testnet
 var Testnet_seed_nodes = []string{
-	"212.8.242.60:40401",
+	"68.183.12.117:40401",
 }
diff --git a/config/version.go b/config/version.go
index a5c3b303..2ebb26fa 100644
--- a/config/version.go
+++ b/config/version.go
@@ -20,4 +20,4 @@ import "github.com/blang/semver/v4"
 // right now it has to be manually changed
 // do we need to include git commitsha??
 
-var Version = semver.MustParse("3.4.69-1.DEROHE.STARGATE+15112021")
+var Version = semver.MustParse("3.4.80-1.DEROHE.STARGATE+20112021")
diff --git a/globals/globals.go b/globals/globals.go
index 306a9798..5cef491d 100644
--- a/globals/globals.go
+++ b/globals/globals.go
@@ -33,6 +33,7 @@ import "go.uber.org/zap"
 import "go.uber.org/zap/zapcore"
 import "github.com/go-logr/logr"
 import "github.com/go-logr/zapr"
+import "github.com/robfig/cron/v3"
 
 import "github.com/deroproject/derohe/config"
 import "github.com/deroproject/derohe/rpc"
@@ -89,6 +90,10 @@ func GetOffsetP2P() time.Duration {
 	return ClockOffsetP2P
 }
 
+var Cron = cron.New(cron.WithChain(
+	cron.Recover(Logger), // or use cron.DefaultLogger
+))
+
 var Dialer proxy.Dialer = proxy.Direct // for proxy and direct connections
 // all outgoing connections , including DNS requests must be made using this
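The shared scheduler declared above is what blockchain.go registers its housekeeping jobs on, and derod/main.go starts it. A condensed sketch of the lifecycle; the job body is elided, and the schedule strings are the ones the patch actually uses:

```go
// Sketch: jobs added to globals.Cron only begin firing once Start is called,
// and the cron.Recover wrapper turns a panicking job into a logged error
// instead of a daemon crash.
func wireHousekeeping() {
	globals.Cron.AddFunc("@every 360s", clean_up_valid_cache) // tx-validity cache sweep
	globals.Cron.AddFunc("@every 60s", func() {
		// mempool / regpool housekeeping, as registered in Blockchain_Start
	})
	globals.Cron.Start() // called from derod's main(), after chain init
}
```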
fmt.Fprintf(w, "KCP_FECShortShards %d\n", kcp.DefaultSnmp.FECShortShards) + } func Dump_metrics_data_directly(logger logr.Logger, specificnamei interface{}) { diff --git a/p2p/chain_bootstrap.go b/p2p/chain_bootstrap.go index e2a8bca5..162cf1f0 100644 --- a/p2p/chain_bootstrap.go +++ b/p2p/chain_bootstrap.go @@ -20,6 +20,7 @@ import "fmt" //import "net" import "time" +import "context" import "math/big" import "math/bits" import "sync/atomic" @@ -40,6 +41,8 @@ import "github.com/deroproject/derohe/cryptography/crypto" // we are expecting other side to have a heavier PoW chain // this is for the case when the chain only moves in pruned state // if after bootstraping the chain can continousky sync for few minutes, this means we have got the job done +// TODO if during bootstrap error occurs, then we must discard data and restart from scratch +// resume may be implemented in future func (connection *Connection) bootstrap_chain() { defer handle_connection_panic(connection) var request ChangeList @@ -55,6 +58,8 @@ func (connection *Connection) bootstrap_chain() { return } + var TimeLimit = 10 * time.Second + // we will request top 60 blocks ctopo := connection.TopoHeight - 50 // last 50 blocks have to be synced, this syncing will help us detect error var topos []int64 @@ -69,7 +74,9 @@ func (connection *Connection) bootstrap_chain() { } fill_common(&request.Common) // fill common info - if err := connection.RConn.Client.Call("Peer.ChangeSet", request, &response); err != nil { + + ctx, _ := context.WithTimeout(context.Background(), TimeLimit) + if err := connection.Client.CallWithContext(ctx, "Peer.ChangeSet", request, &response); err != nil { connection.logger.V(1).Error(err, "Call failed ChangeSet") return } @@ -103,8 +110,9 @@ func (connection *Connection) bootstrap_chain() { ts_request := Request_Tree_Section_Struct{Topo: request.TopoHeights[0], TreeName: []byte(config.BALANCE_TREE), Section: section[:], SectionLength: uint64(path_length)} var ts_response Response_Tree_Section_Struct fill_common(&ts_response.Common) - if err := connection.RConn.Client.Call("Peer.TreeSection", ts_request, &ts_response); err != nil { - connection.logger.V(2).Error(err, "Call failed TreeSection") + ctx, _ := context.WithTimeout(context.Background(), TimeLimit) + if err := connection.Client.CallWithContext(ctx, "Peer.TreeSection", ts_request, &ts_response); err != nil { + connection.logger.V(1).Error(err, "Call failed TreeSection") return } else { // now we must write all the state changes to gravition @@ -167,8 +175,9 @@ func (connection *Connection) bootstrap_chain() { ts_request := Request_Tree_Section_Struct{Topo: request.TopoHeights[0], TreeName: []byte(config.SC_META), Section: section[:], SectionLength: uint64(path_length)} var ts_response Response_Tree_Section_Struct fill_common(&ts_response.Common) - if err := connection.RConn.Client.Call("Peer.TreeSection", ts_request, &ts_response); err != nil { - connection.logger.V(2).Error(err, "Call failed TreeSection") + ctx, _ = context.WithTimeout(context.Background(), TimeLimit) + if err := connection.Client.CallWithContext(ctx, "Peer.TreeSection", ts_request, &ts_response); err != nil { + connection.logger.V(1).Error(err, "Call failed TreeSection") return } else { // now we must write all the state changes to gravition @@ -197,8 +206,9 @@ func (connection *Connection) bootstrap_chain() { sc_request := Request_Tree_Section_Struct{Topo: request.TopoHeights[0], TreeName: ts_response.Keys[j], Section: section[:], SectionLength: uint64(0)} var sc_response 
diff --git a/p2p/chain_sync.go b/p2p/chain_sync.go
index e34d9d46..ca585756 100644
--- a/p2p/chain_sync.go
+++ b/p2p/chain_sync.go
@@ -16,24 +16,17 @@

 package p2p

-//import "fmt"
-
-//import "net"
+import "fmt"
 import "time"
+import "context"
 import "sync/atomic"

-//import "container/list"
-
 import "github.com/deroproject/derohe/config"
 import "github.com/deroproject/derohe/globals"
 import "github.com/deroproject/derohe/block"
 import "github.com/deroproject/derohe/errormsg"
 import "github.com/deroproject/derohe/transaction"

-//import "github.com/deroproject/derohe/cryptography/crypto"
-
-//import "github.com/deroproject/derosuite/blockchain"
-
 // we are expecting the other side to have a heavier PoW chain, try to sync now
 func (connection *Connection) sync_chain() {
@@ -75,8 +68,11 @@ try_again:
 	request.Block_list = append(request.Block_list, globals.Config.Genesis_Block_Hash)
 	request.TopoHeights = append(request.TopoHeights, 0)
 	fill_common(&request.Common) // fill common info
-	if err := connection.RConn.Client.Call("Peer.Chain", request, &response); err != nil {
-		connection.logger.V(2).Error(err, "Call failed Chain", err)
+
+	var TimeLimit = 10 * time.Second
+	ctx, _ := context.WithTimeout(context.Background(), TimeLimit)
+	if err := connection.Client.CallWithContext(ctx, "Peer.Chain", request, &response); err != nil {
+		connection.logger.V(2).Error(err, "Call failed Chain")
 		return
 	}
 	// we have a response, see if it's valid and try to get the blocks
@@ -111,7 +107,7 @@ try_again:
 	connection.logger.V(2).Info("response block list", "count", len(response.Block_list))
 	for i := range response.Block_list {
 		our_topo_order := chain.Load_Block_Topological_order(response.Block_list[i])
-		if our_topo_order != (int64(i)+response.Start_topoheight) || chain.Load_Block_Topological_order(response.Block_list[i]) == -1 { // if block is not in our chain, add it to request list
+		if our_topo_order != (int64(i)+response.Start_topoheight) || our_topo_order == -1 { // if block is not in our chain, add it to request list
 			//queue_block(request.Block_list[i])
 			if max_blocks_to_queue >= 0 {
 				max_blocks_to_queue--
@@ -121,7 +117,8 @@
 				orequest.Block_list = append(orequest.Block_list, response.Block_list[i])
 				fill_common(&orequest.Common)

-				if err := connection.RConn.Client.Call("Peer.GetObject", orequest, &oresponse); err != nil {
+				ctx, _ := context.WithTimeout(context.Background(), TimeLimit)
+				if err := connection.Client.CallWithContext(ctx, "Peer.GetObject", orequest, &oresponse); err != nil {
 					connection.logger.V(2).Error(err, "Call failed GetObject")
 					return
 				} else { // process the response
@@ -130,10 +127,10 @@
 				}
 			}

-
//fmt.Printf("Queuing block %x height %d %s", response.Block_list[i], response.Start_height+int64(i), connection.logid) + // fmt.Printf("Queuing block %x height %d %s", response.Block_list[i], response.Start_height+int64(i), connection.logid) } } else { - connection.logger.V(3).Info("We must have queued but we skipped it at height", "blid", response.Block_list[i], "height", response.Start_height+int64(i)) + connection.logger.V(3).Info("We must have queued but we skipped it at height", "blid", fmt.Sprintf("%x", response.Block_list[i]), "height", response.Start_height+int64(i)) } } diff --git a/p2p/chunk_server.go b/p2p/chunk_server.go index 05cb39e2..f9b51070 100644 --- a/p2p/chunk_server.go +++ b/p2p/chunk_server.go @@ -36,18 +36,13 @@ type Chunks_Per_Block_Data struct { // cleans up chunks every minute func chunks_clean_up() { - for { - time.Sleep(5 * time.Second) // cleanup every 5 seconds - - chunk_map.Range(func(key, value interface{}) bool { - - chunks_per_block := value.(*Chunks_Per_Block_Data) - if time.Now().Sub(chunks_per_block.Created) > time.Second*180 { - chunk_map.Delete(key) - } - return true - }) - } + chunk_map.Range(func(key, value interface{}) bool { + chunks_per_block := value.(*Chunks_Per_Block_Data) + if time.Now().Sub(chunks_per_block.Created) > time.Second*180 { + chunk_map.Delete(key) + } + return true + }) } // return whether chunk exist @@ -64,6 +59,9 @@ func is_chunk_exist(hhash [32]byte, cid uint8) *Block_Chunk { // feed a chunk until we are able to fully decode a chunk func (connection *Connection) feed_chunk(chunk *Block_Chunk, sent int64) error { + chunk_lock.Lock() + defer chunk_lock.Unlock() + if chunk.HHash != chunk.HeaderHash() { connection.logger.V(2).Info("This peer should be banned, since he supplied wrong chunk") connection.exit() diff --git a/p2p/common.go b/p2p/common.go index ac7b49ec..3e25623e 100644 --- a/p2p/common.go +++ b/p2p/common.go @@ -47,8 +47,7 @@ func fill_common_skip_topoheight(common *Common_Struct) { // update some common properties quickly func (connection *Connection) update(common *Common_Struct) { - //connection.Lock() - //defer connection.Unlock() + connection.update_received = time.Now() var hash crypto.Hash atomic.StoreInt64(&connection.Height, common.Height) // satify race detector GOD if common.StableHeight != 0 { diff --git a/p2p/connection_pool.go b/p2p/connection_pool.go index 0010f128..ec5e60c7 100644 --- a/p2p/connection_pool.go +++ b/p2p/connection_pool.go @@ -27,7 +27,7 @@ import "sync" import "sort" import "time" import "strings" -import "math/rand" +import "context" import "sync/atomic" import "runtime/debug" @@ -35,14 +35,14 @@ import "github.com/go-logr/logr" import "github.com/dustin/go-humanize" -import "github.com/paulbellamy/ratecounter" - import "github.com/deroproject/derohe/block" import "github.com/deroproject/derohe/cryptography/crypto" import "github.com/deroproject/derohe/globals" import "github.com/deroproject/derohe/metrics" import "github.com/deroproject/derohe/transaction" +import "github.com/cenkalti/rpc2" + // any connection incoming/outgoing can only be in this state //type Conn_State uint32 @@ -52,133 +52,136 @@ const ( ACTIVE = 2 // "Active" ) -type Queued_Command struct { - Command uint64 // we are waiting for this response - BLID []crypto.Hash - TXID []crypto.Hash - Topos []int64 -} - const MAX_CLOCK_DATA_SET = 16 // This structure is used to do book keeping for the connection and keeps other DATA related to peer // golang restricts 64 bit uint64/int atomic on a 64 bit boundary // therefore 
diff --git a/p2p/connection_pool.go b/p2p/connection_pool.go
index 0010f128..ec5e60c7 100644
--- a/p2p/connection_pool.go
+++ b/p2p/connection_pool.go
@@ -27,7 +27,7 @@ import "sync"
 import "sort"
 import "time"
 import "strings"
-import "math/rand"
+import "context"
 import "sync/atomic"
 import "runtime/debug"
@@ -35,14 +35,14 @@ import "github.com/go-logr/logr"

 import "github.com/dustin/go-humanize"

-import "github.com/paulbellamy/ratecounter"
-
 import "github.com/deroproject/derohe/block"
 import "github.com/deroproject/derohe/cryptography/crypto"
 import "github.com/deroproject/derohe/globals"
 import "github.com/deroproject/derohe/metrics"
 import "github.com/deroproject/derohe/transaction"

+import "github.com/cenkalti/rpc2"
+
 // any connection incoming/outgoing can only be in this state
 //type Conn_State uint32

@@ -52,133 +52,136 @@ const (
 	ACTIVE = 2 // "Active"
 )

-type Queued_Command struct {
-	Command uint64 // we are waiting for this response
-	BLID    []crypto.Hash
-	TXID    []crypto.Hash
-	Topos   []int64
-}
-
 const MAX_CLOCK_DATA_SET = 16

 // This structure is used to do bookkeeping for the connection and keeps other DATA related to peer
 // golang restricts 64 bit uint64/int atomic on a 64 bit boundary
 // therefore all atomics are on the top
 type Connection struct {
+	Client  *rpc2.Client
+	Conn    net.Conn // actual object to talk
+	ConnTls net.Conn // tls layered conn
+
 	Height       int64       // last height sent by peer (first member, due to alignment issues)
 	StableHeight int64       // last stable height
 	TopoHeight   int64       // topo height, current topo height, this is the only thing we require for syncing
 	StateHash    crypto.Hash // statehash at the top
 	Pruned       int64       // till where chain has been pruned on this node
-	LastObjectRequestTime int64  // when was the last item placed in object list
-	BytesIn               uint64 // total bytes in
-	BytesOut              uint64 // total bytes out
-	Latency               int64  // time.Duration // latency to this node when sending timed sync
-
-	Incoming    bool         // is connection incoming or outgoing
-	Addr        *net.TCPAddr // endpoint on the other end
-	Port        uint32       // port advertised by other end as its server; if it's 0 server cannot accept connections
-	Peer_ID     uint64       // Remote peer id
-	SyncNode    bool         // whether the peer has been added to command line as sync node
-	Top_Version uint64       // current hard fork version supported by peer
+	Created               time.Time // when was object created
+	LastObjectRequestTime int64     // when was the last item placed in object list
+	BytesIn               uint64    // total bytes in
+	BytesOut              uint64    // total bytes out
+	Latency               int64     // time.Duration // latency to this node when sending timed sync
+
+	Incoming    bool     // is connection incoming or outgoing
+	Addr        net.Addr // endpoint on the other end
+	Port        uint32   // port advertised by other end as its server; if it's 0 server cannot accept connections
+	Peer_ID     uint64   // Remote peer id
+	SyncNode    bool     // whether the peer has been added to command line as sync node
+	Top_Version uint64   // current hard fork version supported by peer
 	ProtocolVersion string
 	Tag             string // tag for the other end
 	DaemonVersion   string
-	//Exit chan bool // Exit marker that connection needs to be killed
-	ExitCounter int32
-	State       uint32      // state of the connection
-	Top_ID      crypto.Hash // top block id of the connection
-
-	logger logr.Logger // connection specific logger
-	logid  string      // formatted version of connection
-	Requested_Objects [][32]byte // currently unused as we sync up with a single peer at a time
-	Conn  net.Conn        // actual object to talk
-	RConn *RPC_Connection // object for communication
-	// Command_queue *list.List // New protocol is partly syncronous
-	Objects      chan Queued_Command      // contains all objects that are requested
-	SpeedIn      *ratecounter.RateCounter // average speed in last 60 seconds
-	SpeedOut     *ratecounter.RateCounter // average speed in last 60 secs
-	request_time atomic.Value             //time.Time // used to track latency
-	writelock    sync.Mutex               // used to Serialize writes
-
-	previous_mbl []byte // single slot cache
-
-	peer_sent_time time.Time // contains last time when peerlist was sent
+	State  uint32      // state of the connection
+	Top_ID crypto.Hash // top block id of the connection
+
+	logger logr.Logger // connection specific logger
+
+	Requested_Objects [][32]byte // currently unused as we sync up with a single peer at a time
+
+	peer_sent_time   time.Time // contains last time when peerlist was sent
+	update_received  time.Time // last time an update was received
+	ping_in_progress int32     // number of pings in flight against this connection
+
+	ping_count int64

 	clock_index   int
 	clock_offsets [MAX_CLOCK_DATA_SET]time.Duration
 	delays        [MAX_CLOCK_DATA_SET]time.Duration
 	clock_offset  int64 // duration updated on every miniblock

+	onceexit sync.Once
+	Mutex    sync.Mutex // used only by the connection goroutine
 }
+func Address(c *Connection) string {
+	if c.Addr == nil {
+		return ""
+	}
+	return ParseIPNoError(c.Addr.String())
+}
+
 func (c *Connection) exit() {
-	c.RConn.Session.Close()
+	defer globals.Recover(0)
+	c.onceexit.Do(func() {
+		c.ConnTls.Close()
+		c.Conn.Close()
+		c.Client.Close()
+	})
 }

-var connection_map sync.Map // map[string]*Connection{}
-var connection_per_ip_counter = map[string]int{} // only keeps the counter of counter of connections
+// removes a connection from the map
+func Connection_Delete(c *Connection) {
+	connection_map.Range(func(k, value interface{}) bool {
+		v := value.(*Connection)
+		if c.Addr.String() == v.Addr.String() {
+			connection_map.Delete(Address(v))
+			return false
+		}
+		return true
+	})
+}

-// for incoming connections we use their peer id to assertain uniquenesss
-// for outgoing connections, we use the tcp endpoint address, so as not more than 1 connection is done
-func Key(c *Connection) string {
-	if c.Incoming {
-		return fmt.Sprintf("%d", c.Peer_ID)
-	}
-	return string(c.Addr.String()) // Simple []byte => string conversion
+func Connection_Pending_Clear() {
+	connection_map.Range(func(k, value interface{}) bool {
+		v := value.(*Connection)
+		if atomic.LoadUint32(&v.State) == HANDSHAKE_PENDING && time.Now().Sub(v.Created) > 10*time.Second { //and skip ourselves
+			v.exit()
+			v.logger.V(3).Info("Cleaning pending connection")
+		}
+
+		if time.Now().Sub(v.update_received).Round(time.Second).Seconds() > 20 {
+			v.exit()
+			Connection_Delete(v)
+			v.logger.Info("Purging idle connection")
+		}
+
+		if IsAddressInBanList(Address(v)) {
+			v.exit()
+			Connection_Delete(v)
+			v.logger.Info("Purging connection due to ban list")
+		}
+		return true
+	})
 }

+var connection_map sync.Map // map[string]*Connection{}
+
 // check whether an IP is in the map already
 func IsAddressConnected(address string) bool {
-
 	if _, ok := connection_map.Load(strings.TrimSpace(address)); ok {
 		return true
 	}
 	return false
 }

-// add connection to map
+// add connection to map, only if we are not connected already
 // we also check for limits for incoming connections
 // same ip max 8 ip ( considering NAT)
 //same Peer ID 4
-func Connection_Add(c *Connection) {
-	//connection_mutex.Lock()
-	//defer connection_mutex.Unlock()
-
-	ip_count := 0
-	peer_id_count := 0
-
-	incoming_ip := c.Addr.IP.String()
-	incoming_peer_id := c.Peer_ID
-
-	if c.Incoming { // we need extra protection for incoming for various attacks
-
-		connection_map.Range(func(k, value interface{}) bool {
-			v := value.(*Connection)
-			if v.Incoming {
-				if incoming_ip == v.Addr.IP.String() {
-					ip_count++
-				}
-
-				if incoming_peer_id == v.Peer_ID {
-					peer_id_count++
-				}
-			}
-			return true
-		})
-
-	}
-
-	if ip_count >= 8 || peer_id_count >= 4 {
-		c.logger.V(3).Info("IP address already has too many connections, exiting this connection", "ip", incoming_ip, "count", ip_count, "peerid", incoming_peer_id)
+func Connection_Add(c *Connection) bool {
+	if dup, ok := connection_map.LoadOrStore(Address(c), c); !ok {
+		c.Created = time.Now()
+		c.logger.V(3).Info("IP address being added", "ip", c.Addr.String())
+		return true
+	} else {
+		c.logger.V(3).Info("IP address already has one connection, exiting this connection", "ip", c.Addr.String(), "pre", dup.(*Connection).Addr.String())
 		c.exit()
-		return
+		return false
 	}
-
-	connection_map.Store(Key(c), c)
 }
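Connection_Add above replaces the old counting loop with sync.Map.LoadOrStore, an atomic check-and-insert keyed by IP, so at most one connection per address survives. The idiom condensed, as an illustrative sketch:

	var pool sync.Map // ip -> *Connection

	// addOnce stores c for ip unless an entry already exists; reports whether it won.
	func addOnce(ip string, c *Connection) bool {
		if _, loaded := pool.LoadOrStore(ip, c); loaded {
			return false // a connection for this ip already exists; reject the duplicate
		}
		return true // we stored the first (and only) entry for this ip
	}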
 // unique connection list
@@ -199,34 +202,35 @@ func UniqueConnections() map[uint64]*Connection {

 // pings all connections; now a single pass, scheduled periodically
 func ping_loop() {
-	for {
-		time.Sleep(1 * time.Second)
-		connection_map.Range(func(k, value interface{}) bool {
-			c := value.(*Connection)
-			if atomic.LoadUint32(&c.State) != HANDSHAKE_PENDING && GetPeerID() != c.Peer_ID {
-				go func() {
-					defer globals.Recover(3)
-					var request, response Dummy
-					fill_common(&request.Common) // fill common info
+	connection_map.Range(func(k, value interface{}) bool {
+		c := value.(*Connection)
+		if atomic.LoadUint32(&c.State) != HANDSHAKE_PENDING && GetPeerID() != c.Peer_ID /*&& atomic.LoadInt32(&c.ping_in_progress) == 0*/ {
+			go func() {
+				defer globals.Recover(3)
+				atomic.AddInt32(&c.ping_in_progress, 1)
+				defer atomic.AddInt32(&c.ping_in_progress, -1)

-					if c.peer_sent_time.Add(5 * time.Second).Before(time.Now()) {
-						c.peer_sent_time = time.Now()
-						request.Common.PeerList = get_peer_list()
-					}
-					if err := c.RConn.Client.Call("Peer.Ping", request, &response); err != nil {
-						return
-					}
-					c.update(&response.Common) // update common information
-				}()
-			}
-			return true
-		})
-	}
-}
+				var request, response Dummy
+				fill_common(&request.Common) // fill common info

-// add connection to map
-func Connection_Delete(c *Connection) {
-	connection_map.Delete(Key(c))
+				c.ping_count++
+				if c.ping_count%100 == 1 {
+					request.Common.PeerList = get_peer_list()
+				}
+
+				ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+				defer cancel()
+
+				if err := c.Client.CallWithContext(ctx, "Peer.Ping", request, &response); err != nil {
+					c.logger.V(2).Error(err, "ping failed")
+					c.exit()
+					return
+				}
+				c.update(&response.Common) // update common information
+			}()
+		}
+		return true
+	})
 }

 // prints all the connection info to screen
@@ -239,15 +243,21 @@ func Connection_Print() {
 		return true
 	})

-	logger.Info("Connection info for peers", "count", len(clist))
+	version, err := chain.ReadBlockSnapshotVersion(chain.Get_Top_ID())
+	if err != nil {
+		panic(err)
+	}

-	if globals.Arguments["--debug"].(bool) == true {
-		fmt.Printf("%-20s %-16s %-5s %-7s %-7s %-7s %23s %3s %5s %s %s %s %s %16s %16s\n", "Remote Addr", "PEER ID", "PORT", " State", "Latency", "Offset", "S/H/T", "DIR", "QUEUE", " IN", " OUT", " IN SPEED", " OUT SPEED", "Version", "Statehash")
-	} else {
-		fmt.Printf("%-20s %-16s %-5s %-7s %-7s %-7s %17s %3s %5s %s %s %s %s %16s %16s\n", "Remote Addr", "PEER ID", "PORT", " State", "Latency", "Offset", "H/T", "DIR", "QUEUE", " IN", " OUT", " IN SPEED", " OUT SPEED", "Version", "Statehash")
+	StateHash, err := chain.Load_Merkle_Hash(version)
+	if err != nil {
+		panic(err)
 	}

+	logger.Info("Connection info for peers", "count", len(clist), "our Statehash", StateHash)
+
+	fmt.Printf("%-30s %-16s %-5s %-7s %-7s %-7s %23s %3s %5s %s %s %16s %16s\n", "Remote Addr", "PEER ID", "PORT", " State", "Latency", "Offset", "S/H/T", "DIR", "QUEUE", " IN", " OUT", "Version", "Statehash")
+
 	// sort the list
 	sort.Slice(clist, func(i, j int) bool { return clist[i].Addr.String() < clist[j].Addr.String() })
@@ -290,15 +300,10 @@
 		fmt.Print(color_yellow)
 	}
-	if globals.Arguments["--debug"].(bool) == true {
-		hstring := fmt.Sprintf("%d/%d/%d", clist[i].StableHeight, clist[i].Height, clist[i].TopoHeight)
-		fmt.Printf("%-20s %16x %5d %7s %7s %7s %23s %s %5d %7s %7s %8s %9s %16s %s %x\n", clist[i].Addr.IP, clist[i].Peer_ID, clist[i].Port, state, time.Duration(atomic.LoadInt64(&clist[i].Latency)).Round(time.Millisecond).String(), time.Duration(atomic.LoadInt64(&clist[i].clock_offset)).Round(time.Millisecond).String(), hstring, dir, clist[i].isConnectionSyncing(), humanize.Bytes(atomic.LoadUint64(&clist[i].BytesIn)), humanize.Bytes(atomic.LoadUint64(&clist[i].BytesOut)), humanize.Bytes(uint64(clist[i].SpeedIn.Rate()/60)), humanize.Bytes(uint64(clist[i].SpeedOut.Rate()/60)), version, tag, clist[i].StateHash[:])
-
-	} else {
-		hstring := fmt.Sprintf("%d/%d", clist[i].Height, clist[i].TopoHeight)
-		fmt.Printf("%-20s %16x %5d %7s %7s %7s %17s %s %5d %7s %7s %8s %9s %16s %s %x\n", clist[i].Addr.IP, clist[i].Peer_ID, clist[i].Port, state, time.Duration(atomic.LoadInt64(&clist[i].Latency)).Round(time.Millisecond).String(), time.Duration(atomic.LoadInt64(&clist[i].clock_offset)).Round(time.Millisecond).String(), hstring, dir, clist[i].isConnectionSyncing(), humanize.Bytes(atomic.LoadUint64(&clist[i].BytesIn)), humanize.Bytes(atomic.LoadUint64(&clist[i].BytesOut)), humanize.Bytes(uint64(clist[i].SpeedIn.Rate()/60)), humanize.Bytes(uint64(clist[i].SpeedOut.Rate()/60)), version, tag, clist[i].StateHash[:8])
+	ctime := time.Now().Sub(clist[i].Created).Round(time.Second)

-	}
+	hstring := fmt.Sprintf("%d/%d/%d", clist[i].StableHeight, clist[i].Height, clist[i].TopoHeight)
+	fmt.Printf("%-30s %16x %5d %7s %7s %7s %23s %s %5d %7s %7s %16s %s %x\n", Address(clist[i])+" ("+ctime.String()+")", clist[i].Peer_ID, clist[i].Port, state, time.Duration(atomic.LoadInt64(&clist[i].Latency)).Round(time.Millisecond).String(), time.Duration(atomic.LoadInt64(&clist[i].clock_offset)).Round(time.Millisecond).String(), hstring, dir, 0, humanize.Bytes(atomic.LoadUint64(&clist[i].BytesIn)), humanize.Bytes(atomic.LoadUint64(&clist[i].BytesOut)), version, tag, clist[i].StateHash[:])

 	fmt.Print(color_normal)
 }
@@ -328,21 +333,6 @@ func Best_Peer_Height() (best_height, best_topo_height int64) {
 	return
 }

-// this function return peer count which have successful handshake
-func Disconnect_All() (Count uint64) {
-	return
-	/*
-		connection_mutex.Lock()
-		for _, v := range connection_map {
-			// v.Lock()
-			close(v.Exit) // close the connection
-			//v.Unlock()
-		}
-		connection_mutex.Unlock()
-		return
-	*/
-}
-
 // this function returns peer count which have a successful handshake
 func Peer_Count() (Count uint64) {
 	connection_map.Range(func(k, value interface{}) bool {
@@ -355,29 +345,8 @@
 	return
 }

-// this function returnw random connection which have successful handshake
-func Random_Connection(height int64) (c *Connection) {
-
-	var clist []*Connection
-
-	connection_map.Range(func(k, value interface{}) bool {
-		v := value.(*Connection)
-		if atomic.LoadInt64(&v.Height) >= height {
-			clist = append(clist, v)
-		}
-		return true
-	})
-
-	if len(clist) > 0 {
-		return clist[rand.Int()%len(clist)]
-	}
-
-	return nil
-}
-
 // this returns count of peers in both directions
 func Peer_Direction_Count() (Incoming uint64, Outgoing uint64) {
-
 	connection_map.Range(func(k, value interface{}) bool {
 		v := value.(*Connection)
 		if atomic.LoadUint32(&v.State) != HANDSHAKE_PENDING && GetPeerID() != v.Peer_ID {
@@ -389,41 +358,9 @@
 		}
 		return true
 	})
-
 	return
 }
-func broadcast_Block_tester(topo int64) (err error) {
-
-	blid, err := chain.Load_Block_Topological_order_at_index(topo)
-	if err != nil {
-		return fmt.Errorf("err occurred topo %d err %s\n", topo, err)
-	}
-	var cbl block.Complete_Block
-	bl, err := chain.Load_BL_FROM_ID(blid)
-	if err != nil {
-		return err
-	}
-
-	cbl.Bl = bl
-	for j := range bl.Tx_hashes {
-		var tx_bytes []byte
-		if tx_bytes, err = chain.Store.Block_tx_store.ReadTX(bl.Tx_hashes[j]); err != nil {
-			return err
-		}
-		var tx transaction.Transaction
-		if err = tx.Deserialize(tx_bytes); err != nil {
-			return err
-		}
-
-		cbl.Txs = append(cbl.Txs, &tx) // append all the txs
-
-	}
-
-	Broadcast_Block(&cbl, 0)
-	return nil
-}
-
 func Broadcast_Block(cbl *block.Complete_Block, PeerID uint64) {
 	Broadcast_Block_Coded(cbl, PeerID)
 }
@@ -491,7 +428,7 @@ func broadcast_Block_Coded(cbl *block.Complete_Block, PeerID uint64, first_seen
 	connection.logger.V(3).Info("Sending erasure coded chunk to peer ", "cid", cid)
 	var dummy Dummy
 	fill_common(&peer_specific_list.Common) // fill common info
-	if err := connection.RConn.Client.Call("Peer.NotifyINV", peer_specific_list, &dummy); err != nil {
+	if err := connection.Client.Call("Peer.NotifyINV", peer_specific_list, &dummy); err != nil {
 		return
 	}
 	connection.update(&dummy.Common) // update common information
@@ -518,15 +455,10 @@ done:

 // this function is triggered from 2 points, one when we receive an unknown block which can be successfully added to chain
 // second from the blockchain which has to relay locally mined blocks as soon as possible
 func broadcast_Chunk(chunk *Block_Chunk, PeerID uint64, first_seen int64) { // if peerid is provided it is skipped
-
 	defer globals.Recover(3)

-	/*if IsSyncing() { // if we are syncing, do NOT broadcast the block
-		return
-	}*/
-
 	our_height := chain.Get_Height()
-	// build the request once and dispatch it to all possible peers
+
 	count := 0
 	unique_map := UniqueConnections()
@@ -565,7 +497,7 @@ func broadcast_Chunk(chunk *Block_Chunk, PeerID uint64, first_seen int64) { // i
 	connection.logger.V(3).Info("Sending erasure coded chunk INV to peer ", "raw", fmt.Sprintf("%x", chunkid), "blid", fmt.Sprintf("%x", chunk.BLID), "cid", chunk.CHUNK_ID, "hhash", fmt.Sprintf("%x", hhash), "exists", nil != is_chunk_exist(hhash, uint8(chunk.CHUNK_ID)))
 	var dummy Dummy
 	fill_common(&peer_specific_list.Common) // fill common info
-	if err := connection.RConn.Client.Call("Peer.NotifyINV", peer_specific_list, &dummy); err != nil {
+	if err := connection.Client.Call("Peer.NotifyINV", peer_specific_list, &dummy); err != nil {
 		return
 	}
 	connection.update(&dummy.Common) // update common information
@@ -579,17 +511,17 @@ func broadcast_Chunk(chunk *Block_Chunk, PeerID uint64, first_seen int64) { // i

 // we can only broadcast a block which is in our db
 // this function is triggered from 2 points, one when we receive an unknown block which can be successfully added to chain
 // second from the blockchain which has to relay locally mined blocks as soon as possible
-func Broadcast_MiniBlock(mbl block.MiniBlock, PeerID uint64) { // if peerid is provided it is skipped
-	broadcast_MiniBlock(mbl, PeerID, globals.Time().UTC().UnixMicro())
+func Broadcast_MiniBlock(mbls []block.MiniBlock, PeerID uint64) { // if peerid is provided it is skipped
+	broadcast_MiniBlock(mbls, PeerID, globals.Time().UTC().UnixMicro())
 }

-func broadcast_MiniBlock(mbl block.MiniBlock, PeerID uint64, first_seen int64) { // if peerid is provided it is skipped
+func broadcast_MiniBlock(mbls []block.MiniBlock, PeerID uint64, first_seen int64) { // if peerid is provided it is skipped
 	defer globals.Recover(3)

-	miniblock_serialized := mbl.Serialize()
-
 	var peer_specific_block Objects
-	peer_specific_block.MiniBlocks = append(peer_specific_block.MiniBlocks, miniblock_serialized)
+	for _, mbl := range mbls {
+		peer_specific_block.MiniBlocks = append(peer_specific_block.MiniBlocks, mbl.Serialize())
+	}
 	fill_common(&peer_specific_block.Common) // fill common info
 	peer_specific_block.Sent = first_seen
@@ -622,7 +554,7 @@ func broadcast_MiniBlock(mbl block.MiniBlock, PeerID uint64, first_seen int64) {
 	defer globals.Recover(3)
 	var dummy Dummy

-	if err := connection.RConn.Client.Call("Peer.NotifyMiniBlock", peer_specific_block, &dummy); err != nil {
+	if err := connection.Client.Call("Peer.NotifyMiniBlock", peer_specific_block, &dummy); err != nil {
 		return
 	}
 	connection.update(&dummy.Common) // update common information
@@ -682,7 +614,7 @@ func broadcast_Tx(tx *transaction.Transaction, PeerID uint64, sent int64) (relay
 	var dummy Dummy
 	fill_common(&dummy.Common) // fill common info

-	if err := connection.RConn.Client.Call("Peer.NotifyINV", request, &dummy); err != nil {
+	if err := connection.Client.Call("Peer.NotifyINV", request, &dummy); err != nil {
 		return
 	}
 	connection.update(&dummy.Common) // update common information
@@ -698,33 +630,6 @@ func broadcast_Tx(tx *transaction.Transaction, PeerID uint64, sent int64) (relay
 	return
 }

-//var sync_in_progress bool
-
-// we can tell whether we are syncing by seeing the pending queue of expected response
-// if objects response are queued, we are syncing
-// if even one of the connection is syncing, then we are syncronising
-// returns a number how many blocks are queued
-func (connection *Connection) isConnectionSyncing() (count int) {
-	//connection.Lock()
-	//defer connection.Unlock()
-
-	if atomic.LoadUint32(&connection.State) == HANDSHAKE_PENDING { // skip pre-handshake connections
-		return 0
-	}
-
-	// check whether 15 secs have passed, if yes close the connection
-	// so we can try some other connection
-	if len(connection.Objects) > 0 {
-		if time.Now().Unix() >= (13 + atomic.LoadInt64(&connection.LastObjectRequestTime)) {
-			connection.exit()
-			return 0
-		}
-	}
-
-	return len(connection.Objects)
-
-}
-
 // trigger a sync with a random peer
 func trigger_sync() {
 	defer globals.Recover(3)
@@ -800,22 +705,6 @@ func trigger_sync() {
 	}
 }

-//detect if something is queued to any of the peer
-// is something is queue we are syncing
-func IsSyncing() (result bool) {
-
-	syncing := false
-	connection_map.Range(func(k, value interface{}) bool {
-		v := value.(*Connection)
-		if v.isConnectionSyncing() != 0 {
-			syncing = true
-			return false
-		}
-		return true
-	})
-	return syncing
-}
-
 //go:noinline
 func Abs(n int64) int64 {
 	if n < 0 {
@@ -826,22 +715,18 @@ func Abs(n int64) int64 {

 // detect whether we are behind any of the connected peers and trigger sync ASAP
 // randomly with one of the peers
-func syncroniser() {
-	delay := time.NewTicker(time.Second)
-	for {
-		select {
-		case <-Exit_Event:
-			return
-		case <-delay.C:
-		}
-		calculate_network_time() // calculate time every sec
+var single_sync int32

-		if !IsSyncing() {
-			trigger_sync() // check whether we are out of sync
-		}
+func syncroniser() {
+
+	defer atomic.AddInt32(&single_sync, -1)
+	if atomic.AddInt32(&single_sync, 1) != 1 {
+		return
 	}
+
+	calculate_network_time() // recompute the network time offset
+	trigger_sync()           // check whether we are out of sync
 }
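syncroniser is now fired by cron every 2 seconds; the single_sync counter makes the body single-flight, so a slow sync cannot pile up overlapping runs. The guard in isolation, as an illustrative sketch:

	var inFlight int32

	// runExclusive executes job only if no other invocation is currently running.
	func runExclusive(job func()) {
		if atomic.AddInt32(&inFlight, 1) != 1 { // another run is already in progress
			atomic.AddInt32(&inFlight, -1)
			return
		}
		defer atomic.AddInt32(&inFlight, -1)
		job()
	}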
 // update P2P time
@@ -893,58 +778,3 @@ func calculate_network_time() {

 	globals.ClockOffsetP2P = time.Duration(total / count)
 }
-
-// will return nil, if no peers available
-func random_connection() *Connection {
-	unique_map := UniqueConnections()
-
-	var clist []*Connection
-
-	for _, value := range unique_map {
-		clist = append(clist, value)
-	}
-
-	if len(clist) == 0 {
-		return nil
-	} else if len(clist) == 1 {
-		return clist[0]
-	}
-
-	// sort the list random
-	// do random shuffling, can we get away with len/2 random shuffling
-	globals.Global_Random.Shuffle(len(clist), func(i, j int) {
-		clist[i], clist[j] = clist[j], clist[i]
-	})
-
-	return clist[0]
-}
-
-// this will request a tx
-func (c *Connection) request_tx(txid [][32]byte, random bool) (err error) {
-	var need ObjectList
-	var oresponse Objects
-
-	need.Tx_list = append(need.Tx_list, txid...)
-
-	connection := c
-	if random {
-		connection = random_connection()
-	}
-	if connection == nil {
-		err = fmt.Errorf("No peer available")
-		return
-	}
-
-	fill_common(&need.Common) // fill common info
-	if err = c.RConn.Client.Call("Peer.GetObject", need, &oresponse); err != nil {
-		c.exit()
-		return
-	} else { // process the response
-		if err = c.process_object_response(oresponse, 0, false); err != nil {
-			return
-		}
-	}
-
-	return
-
-}
diff --git a/p2p/controller.go b/p2p/controller.go
index ffc68bfe..f3cff3a4 100644
--- a/p2p/controller.go
+++ b/p2p/controller.go
@@ -16,17 +16,18 @@

 package p2p

-//import "os"
 import "fmt"
 import "net"
-import "net/rpc"
+
+//import "net/url"
 import "time"
 import "sort"
+import "sync"
 import "strings"
 import "math/big"
 import "strconv"

-//import "crypto/rsa"
+import "crypto/sha1"
 import "crypto/ecdsa"
 import "crypto/elliptic"
@@ -44,6 +45,14 @@ import "github.com/deroproject/derohe/globals"
 import "github.com/deroproject/derohe/metrics"
 import "github.com/deroproject/derohe/blockchain"

+import "github.com/xtaci/kcp-go/v5"
+import "golang.org/x/crypto/pbkdf2"
+import "golang.org/x/time/rate"
+
+import "github.com/cenkalti/rpc2"
+
+//import "github.com/txthinking/socks5"
+
 var chain *blockchain.Blockchain // external reference to chain

 var P2P_Port int // this will be exported while doing handshake
@@ -58,6 +67,28 @@ var nonbanlist []string // any ips in this list will never be banned

 var ClockOffset time.Duration //Clock Offset related to all the peers connected

+// backoff: if we have initiated a connection, we will not connect to that ip again for another 10 secs
+var backoff = map[string]int64{} // if server receives a connection, then it will not initiate connection to that ip for another 60 secs
+var backoff_mutex = sync.Mutex{}
+
+// return true if we should back off, else we can connect
+func shouldwebackoff(ip string) bool {
+	backoff_mutex.Lock()
+	defer backoff_mutex.Unlock()
+
+	now := time.Now().Unix()
+	for k, v := range backoff { // prune expired entries
+		if v < now {
+			delete(backoff, k)
+		}
+	}
+
+	if backoff[ip] != 0 { // still present, so back off
+		return true
+	}
+	return false
+}
+
 // Initialize P2P subsystem
 func P2P_Init(params map[string]interface{}) error {
 	logger = globals.Logger.WithName("P2P") // all components must use this logger
@@ -105,11 +136,13 @@ func P2P_Init(params map[string]interface{}) error {
 		}
 	}

-	go P2P_Server_v2()   // start accepting connections
-	go P2P_engine()      // start outgoing engine
-	go syncroniser()     // start sync engine
-	go chunks_clean_up() // clean up chunks
-	go ping_loop()       // ping loop
+	go P2P_Server_v2()                                         // start accepting connections
+	go P2P_engine()                                            // start outgoing engine
+	globals.Cron.AddFunc("@every 2s", syncroniser)             // start sync engine
+	globals.Cron.AddFunc("@every 5s", Connection_Pending_Clear) // clean dead connections
+	globals.Cron.AddFunc("@every 10s", ping_loop)              // ping everyone
+	globals.Cron.AddFunc("@every 10s", chunks_clean_up)        // clean chunks
+
 	go time_check_routine() // check whether server time is in sync using ntp

 	metrics.Set.NewGauge("p2p_peer_count", func() float64 { // set a new gauge
@@ -215,52 +248,109 @@ func P2P_engine() {
 }
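The per-task goroutine loops give way to robfig/cron jobs here; because globals.Cron is built with cron.WithChain(cron.Recover(...)), a panic in one tick is logged instead of killing the daemon. The scheduling idiom in isolation, as an illustrative sketch (it assumes the cron instance is started somewhere during init):

	c := cron.New(cron.WithChain(cron.Recover(cron.DefaultLogger))) // panics are recovered and logged
	c.AddFunc("@every 2s", func() { fmt.Println("tick") })          // same spec syntax the patch uses
	c.Start()                                                       // jobs run on cron's goroutine until c.Stop()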
+func tunekcp(conn *kcp.UDPSession) {
+	conn.SetACKNoDelay(true)
+	conn.SetNoDelay(1, 10, 2, 1) // tuning parameters for local stack
+}
+
 // will try to connect with given endpoint
 // will block until the connection dies or is killed
 func connect_with_endpoint(endpoint string, sync_node bool) {
 	defer globals.Recover(2)

-	remote_ip, err := net.ResolveTCPAddr("tcp", endpoint)
+	remote_ip, err := net.ResolveUDPAddr("udp", endpoint)
 	if err != nil {
 		logger.V(3).Error(err, "Resolve address failed:", "endpoint", endpoint)
 		return
 	}

+	if IsAddressInBanList(ParseIPNoError(remote_ip.IP.String())) {
+		logger.V(2).Info("Connecting to banned IP is prohibited", "IP", remote_ip.IP.String())
+		return
+	}
+
 	// check whether we are already connected to this address, if yes, return
-	if IsAddressConnected(remote_ip.String()) {
+	if IsAddressConnected(ParseIPNoError(remote_ip.String())) {
+		logger.V(4).Info("outgoing address is already connected", "ip", remote_ip.String())
 		return //nil, fmt.Errorf("Already connected")
 	}

+	if shouldwebackoff(ParseIPNoError(remote_ip.String())) {
+		logger.V(1).Info("backing off from this connection", "ip", remote_ip.String())
+		return
+	} else {
+		backoff_mutex.Lock()
+		backoff[ParseIPNoError(remote_ip.String())] = time.Now().Unix() + 10
+		backoff_mutex.Unlock()
+	}
+
+	var masterkey = pbkdf2.Key(globals.Config.Network_ID.Bytes(), globals.Config.Network_ID.Bytes(), 1024, 32, sha1.New)
+	var blockcipher, _ = kcp.NewAESBlockCrypt(masterkey)
+
+	var conn *kcp.UDPSession
+
 	// since we may be connecting through socks, grab the remote ip for our purpose right now
-	conn, err := globals.Dialer.Dial("tcp", remote_ip.String())
+	//conn, err := globals.Dialer.Dial("tcp", remote_ip.String())
+	if globals.Arguments["--socks-proxy"] == nil {
+		conn, err = kcp.DialWithOptions(remote_ip.String(), blockcipher, 10, 3)
+	} else { // we must move through a socks 5 UDP ASSOCIATE supporting proxy, ssh implementation is partial
+		err = fmt.Errorf("socks proxying is not supported")
+		logger.V(0).Error(err, "Not supported", "server", globals.Arguments["--socks-proxy"])
+		return
+		/*uri, err := url.Parse("socks5://" + globals.Arguments["--socks-proxy"].(string)) // "socks5://demo:demo@192.168.99.100:1080"
		if err != nil {
			logger.V(0).Error(err, "Error parsing socks proxy", "server", globals.Arguments["--socks-proxy"])
			return
		}
		_ = uri
		sserver := uri.Host
		if uri.Port() != "" {

			host, _, err := net.SplitHostPort(uri.Host)
			if err != nil {
				logger.V(0).Error(err, "Error parsing socks proxy", "server", globals.Arguments["--socks-proxy"])
				return
			}
			sserver = host + ":" + uri.Port()
		}

		fmt.Printf("sserver %s host %s port %s\n", sserver, uri.Host, uri.Port())
		username := ""
		password := ""
		if uri.User != nil {
			username = uri.User.Username()
			password, _ = uri.User.Password()
		}
		tcpTimeout := 10
		udpTimeout := 10
		c, err := socks5.NewClient(sserver, username, password, tcpTimeout, udpTimeout)
		if err != nil {
			logger.V(0).Error(err, "Error connecting to socks proxy", "server", globals.Arguments["--socks-proxy"])
			return
		}
		udpconn, err := c.Dial("udp", remote_ip.String())
		if err != nil {
			logger.V(0).Error(err, "Error connecting to remote host using socks proxy", "socks", globals.Arguments["--socks-proxy"], "remote", remote_ip.String())
			return
		}
		conn, err = kcp.NewConn(remote_ip.String(), blockcipher, 10, 3, udpconn)
		*/
+	}

-	//conn, err := tls.DialWithDialer(&globals.Dialer, "tcp", remote_ip.String(),&tls.Config{InsecureSkipVerify: true})
-	//conn, err := tls.Dial("tcp", remote_ip.String(),&tls.Config{InsecureSkipVerify: true})
 	if err != nil {
 		logger.V(3).Error(err, "Dial failed", "endpoint", endpoint)
-		Peer_SetFail(remote_ip.String()) // update peer list as we see
-		return //nil, fmt.Errorf("Dial failed err %s", err.Error())
+		Peer_SetFail(ParseIPNoError(remote_ip.String())) // update peer list as we see
+		conn.Close()
+		return //nil, fmt.Errorf("Dial failed err %s", err.Error())
 	}

-	tcpc := conn.(*net.TCPConn)
-	// detection time: tcp_keepalive_time + tcp_keepalive_probes + tcp_keepalive_intvl
-	// default on linux:  30 + 8 * 30
-	// default on osx:    30 + 8 * 75
-	tcpc.SetKeepAlive(true)
-	tcpc.SetKeepAlivePeriod(8 * time.Second)
-	tcpc.SetLinger(0) // discard any pending data
+	tunekcp(conn) // set tunings for low latency

-	//conn.SetKeepAlive(true) // set keep alive true
-	//conn.SetKeepAlivePeriod(10*time.Second) // keep alive every 10 secs
-
-	// upgrade connection TO TLS ( tls.Dial does NOT support proxy)
 	// TODO we need to choose fastest cipher here ( so both clients/servers are not loaded)
-	conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
-
-	process_connection(conn, remote_ip, false, sync_node)
+	conntls := tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
+	process_outgoing_connection(conn, conntls, remote_ip, false, sync_node)

-	//Handle_Connection(conn, remote_ip, false, sync_node) // handle connection
 }
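The transport switch in a nutshell: both ends derive the same AES key deterministically from the network ID, speak KCP over UDP with 10 data + 3 parity FEC shards, and layer TLS inside for stream semantics. The client side condensed, as an illustrative sketch with error handling elided (networkID stands in for globals.Config.Network_ID.Bytes()):

	key := pbkdf2.Key(networkID, networkID, 1024, 32, sha1.New) // shared secret, same on every node
	cipher, _ := kcp.NewAESBlockCrypt(key)
	sess, _ := kcp.DialWithOptions("203.0.113.7:40401", cipher, 10, 3) // 10 data + 3 parity shards
	sess.SetACKNoDelay(true)
	sess.SetNoDelay(1, 10, 2, 1)
	inner := tls.Client(sess, &tls.Config{InsecureSkipVerify: true}) // cert is throwaway; the KCP key gates entry
	_ = inner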
 // maintains a persistent connection to endpoint
@@ -322,7 +412,7 @@ func maintain_connection_to_peers() {
 		logger.Info("Min outgoing peers", "min-peers", Min_Peers)
 	}

-	delay := time.NewTicker(time.Second)
+	delay := time.NewTicker(200 * time.Millisecond)

 	for {
 		select {
@@ -339,7 +429,7 @@
 		}

 		peer := find_peer_to_connect(1)
-		if peer != nil {
+		if peer != nil && !IsAddressConnected(ParseIPNoError(peer.Address)) {
 			go connect_with_endpoint(peer.Address, false)
 		}
 	}
@@ -347,6 +437,8 @@

 func P2P_Server_v2() {

+	var accept_limiter = rate.NewLimiter(10.0, 40) // 10 incoming per sec, burst of 40 is okay
+
 	default_address := "0.0.0.0:0" // by default choose a random port
 	if _, ok := globals.Arguments["--p2p-bind"]; ok && globals.Arguments["--p2p-bind"] != nil {
 		addr, err := net.ResolveTCPAddr("tcp", globals.Arguments["--p2p-bind"].(string))
@@ -363,115 +455,193 @@
 		}
 	}

+	srv := rpc2.NewServer()
+	srv.OnConnect(func(c *rpc2.Client) {
+		remote_addr_interface, _ := c.State.Get("addr")
+		remote_addr := remote_addr_interface.(net.Addr)
+
+		conn_interface, _ := c.State.Get("conn")
+		conn := conn_interface.(net.Conn)
+
+		tlsconn_interface, _ := c.State.Get("tlsconn")
+		tlsconn := tlsconn_interface.(net.Conn)
+
+		connection := &Connection{Client: c, Conn: conn, ConnTls: tlsconn, Addr: remote_addr, State: HANDSHAKE_PENDING, Incoming: true}
+		connection.logger = logger.WithName("incoming").WithName(remote_addr.String())
+
+		c.State.Set("c", connection) // set pointer to connection
+
+		//connection.logger.Info("connected OnConnect")
+		go func() {
+			time.Sleep(2 * time.Second)
+			connection.dispatch_test_handshake()
+		}()
+
+	})
+
+	set_handlers(srv)
+
 	tlsconfig := &tls.Config{Certificates: []tls.Certificate{generate_random_tls_cert()}}
 	//l, err := tls.Listen("tcp", default_address, tlsconfig) // listen as TLS server
+	_ = tlsconfig
+
+	var masterkey = pbkdf2.Key(globals.Config.Network_ID.Bytes(), globals.Config.Network_ID.Bytes(), 1024, 32, sha1.New)
+	var blockcipher, _ = kcp.NewAESBlockCrypt(masterkey)
+
 	// listen for incoming connections: KCP sessions with TLS layered on top
-	l, err := net.Listen("tcp", default_address) // listen as simple TCP server
+	l, err := kcp.ListenWithOptions(default_address, blockcipher, 10, 3)
 	if err != nil {
 		logger.Error(err, "Could not listen", "address", default_address)
 		return
 	}
 	defer l.Close()
-	P2P_Port = int(l.Addr().(*net.TCPAddr).Port)
-	logger.Info("P2P is listening", "address", l.Addr().String())
+	_, P2P_Port_str, _ := net.SplitHostPort(l.Addr().String())
+	P2P_Port, _ = strconv.Atoi(P2P_Port_str)

-	// p2p is shutting down, close the listening socket
-	go func() { <-Exit_Event; l.Close() }()
+	logger.Info("P2P is listening", "address", l.Addr().String())

 	// A common pattern is to start a loop to continuously accept connections
 	for {
-		conn, err := l.Accept() //accept connections using Listener.Accept()
+		conn, err := l.AcceptKCP() //accept KCP sessions
 		if err != nil {
 			select {
 			case <-Exit_Event:
+				l.Close() // p2p is shutting down, close the listening socket
 				return
 			default:
 			}
-			logger.Error(err, "Err while accepting incoming connection")
+			logger.V(1).Error(err, "Err while accepting incoming connection")
 			continue
 		}

-		raddr := conn.RemoteAddr().(*net.TCPAddr)
-
-		//if incoming IP is banned, disconnect now
-		if IsAddressInBanList(raddr.IP.String()) {
-			logger.Info("Incoming IP is banned, disconnecting now", "IP", raddr.IP.String())
+		if !accept_limiter.Allow() { // admit only if the rate limiter allows, else drop the connection
 			conn.Close()
-		} else {
+			continue
+		}
+
+		raddr := conn.RemoteAddr().(*net.UDPAddr)

-			tcpc := conn.(*net.TCPConn)
-			// detection time: tcp_keepalive_time + tcp_keepalive_probes + tcp_keepalive_intvl
-			// default on linux:  30 + 8 * 30
-			// default on osx:    30 + 8 * 75
-			tcpc.SetKeepAlive(true)
-			tcpc.SetKeepAlivePeriod(8 * time.Second)
-			tcpc.SetLinger(0) // discard any pending data
+		backoff_mutex.Lock()
+		backoff[ParseIPNoError(raddr.String())] = time.Now().Unix() + globals.Global_Random.Int63n(200) // random backoff of up to 200 secs
+		backoff_mutex.Unlock()

-			tlsconn := tls.Server(conn, tlsconfig)
-			go process_connection(tlsconn, raddr, true, false) // handle connection in a different go routine
+		logger.V(3).Info("accepting incoming connection", "raddr", raddr.String())
+
+		if IsAddressConnected(ParseIPNoError(raddr.String())) {
+			logger.V(4).Info("incoming address is already connected", "ip", raddr.String())
+			conn.Close()
+			continue
+
+		} else if IsAddressInBanList(ParseIPNoError(raddr.IP.String())) { //if incoming IP is banned, disconnect now
+			logger.V(2).Info("Incoming IP is banned, disconnecting now", "IP", raddr.IP.String())
+			conn.Close()
+			continue
 		}
+
+		tunekcp(conn) // tuning parameters for local stack
+		tlsconn := tls.Server(conn, tlsconfig)
+
+		state := rpc2.NewState()
+		state.Set("addr", raddr)
+		state.Set("conn", conn)
+		state.Set("tlsconn", tlsconn)
+
+		go srv.ServeCodecWithState(NewCBORCodec(tlsconn), state)
+
 	}
 }
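The accept loop gates admission with golang.org/x/time/rate; Allow() consumes one token when available and never blocks, which is what a hot accept path wants. The pattern in isolation, as an illustrative sketch (listener and serve are placeholders):

	limiter := rate.NewLimiter(10, 40) // refills 10 tokens/sec, bucket of 40 absorbs bursts
	for {
		conn, err := listener.Accept()
		if err != nil {
			continue
		}
		if !limiter.Allow() { // bucket empty: shed load instead of queueing
			conn.Close()
			continue
		}
		go serve(conn)
	}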
panic(fmt.Sprintf("object cannot handle such handler %T", base)) + + } +} - var rconn *RPC_Connection - var err error - if incoming { - rconn, err = wait_stream_creation_server_side(conn) // do server side processing +func getc(client *rpc2.Client) *Connection { + if ci, found := client.State.Get("c"); found { + return ci.(*Connection) } else { - rconn, err = stream_creation_client_side(conn) // do client side processing + panic("no connection attached") + return nil } - if err == nil { +} - var RPCSERVER = rpc.NewServer() - c := &Connection{RConn: rconn, Addr: remote_addr, State: HANDSHAKE_PENDING, Incoming: incoming, SyncNode: sync_node} - RPCSERVER.RegisterName("Peer", c) // register the handlers +// we need the following RPCS to work +func set_handlers(o interface{}) { + set_handler(o, "Peer.Handshake", func(client *rpc2.Client, args Handshake_Struct, reply *Handshake_Struct) error { + return getc(client).Handshake(args, reply) + }) + set_handler(o, "Peer.Chain", func(client *rpc2.Client, args Chain_Request_Struct, reply *Chain_Response_Struct) error { + return getc(client).Chain(args, reply) + }) + set_handler(o, "Peer.ChangeSet", func(client *rpc2.Client, args ChangeList, reply *Changes) error { + return getc(client).ChangeSet(args, reply) + }) + set_handler(o, "Peer.NotifyINV", func(client *rpc2.Client, args ObjectList, reply *Dummy) error { + return getc(client).NotifyINV(args, reply) + }) + set_handler(o, "Peer.GetObject", func(client *rpc2.Client, args ObjectList, reply *Objects) error { + return getc(client).GetObject(args, reply) + }) + set_handler(o, "Peer.TreeSection", func(client *rpc2.Client, args Request_Tree_Section_Struct, reply *Response_Tree_Section_Struct) error { + return getc(client).TreeSection(args, reply) + }) + set_handler(o, "Peer.NotifyMiniBlock", func(client *rpc2.Client, args Objects, reply *Dummy) error { + return getc(client).NotifyMiniBlock(args, reply) + }) + set_handler(o, "Peer.Ping", func(client *rpc2.Client, args Dummy, reply *Dummy) error { + return getc(client).Ping(args, reply) + }) - if incoming { - c.logger = logger.WithName("incoming").WithName(remote_addr.String()) - } else { - c.logger = logger.WithName("outgoing").WithName(remote_addr.String()) - } - go func() { - defer func() { - if r := recover(); r != nil { - logger.V(1).Error(nil, "Recovered while handling connection", "r", r, "stack", debug.Stack()) - conn.Close() - } - }() - //RPCSERVER.ServeConn(rconn.ServerConn) // start single threaded rpc server with GOB encoding - RPCSERVER.ServeCodec(NewCBORServerCodec(rconn.ServerConn)) // use CBOR encoding on rpc - }() +} + +func process_outgoing_connection(conn net.Conn, tlsconn net.Conn, remote_addr net.Addr, incoming, sync_node bool) { + defer globals.Recover(0) + + client := rpc2.NewClientWithCodec(NewCBORCodec(tlsconn)) + + c := &Connection{Client: client, Conn: conn, ConnTls: tlsconn, Addr: remote_addr, State: HANDSHAKE_PENDING, Incoming: incoming, SyncNode: sync_node} + defer c.exit() + c.logger = logger.WithName("outgoing").WithName(remote_addr.String()) + set_handlers(client) + + client.State = rpc2.NewState() + client.State.Set("c", c) + go func() { + time.Sleep(2 * time.Second) c.dispatch_test_handshake() + }() - <-rconn.Session.CloseChan() - Connection_Delete(c) - //fmt.Printf("closing connection status err: %s\n",err) - } - conn.Close() + // c.logger.V(4).Info("client running loop") + client.Run() // see the original + c.logger.V(4).Info("process_connection finished") } // shutdown the p2p component func P2P_Shutdown() { - 
-	close(Exit_Event) // send signal to all connections to exit
-	save_peer_list()  // save peer list
-	save_ban_list()   // save ban list
+	//close(Exit_Event) // send signal to all connections to exit
+	save_peer_list() // save peer list
+	save_ban_list()  // save ban list

 	// TODO we must wait for connections to kill themselves
-	time.Sleep(1 * time.Second)
 	logger.Info("P2P Shutdown")
 	atomic.AddUint32(&globals.Subsystem_Active, ^uint32(0)) // this decrements 1 from subsystem
@@ -538,28 +708,21 @@ func generate_random_tls_cert() tls.Certificate {
 	return tlsCert
 }

-/*
-// register all the handlers
-func register_handlers(){
-	arpc.DefaultHandler.Handle("/handshake",Handshake_Handler)
-	arpc.DefaultHandler.Handle("/active",func (ctx *arpc.Context) { // set the connection active
-		if c,ok := ctx.Client.Get("connection");ok {
-			connection := c.(*Connection)
-			atomic.StoreUint32(&connection.State, ACTIVE)
-		}} )
-
-	arpc.DefaultHandler.HandleConnected(OnConnected_Handler) // all incoming connections will first processed here
-	arpc.DefaultHandler.HandleDisconnected(OnDisconnected_Handler) // all disconnected
-}
-
+func ParseIP(s string) (string, error) {
+	ip, _, err := net.SplitHostPort(s)
+	if err == nil {
+		return ip, nil
+	}
+	ip2 := net.ParseIP(s)
+	if ip2 == nil {
+		return "", fmt.Errorf("invalid IP")
+	}

-// triggers when new clients connect and
-func OnConnected_Handler(c *arpc.Client){
-	dispatch_test_handshake(c, c.Conn.RemoteAddr().(*net.TCPAddr) ,true,false) // client connected we must handshake
+	return ip2.String(), nil
 }

-func OnDisconnected_Handler(c *arpc.Client){
-	c.Stop()
+func ParseIPNoError(s string) string {
+	ip, _ := ParseIP(s)
+	return ip
 }
-*/
diff --git a/p2p/peer_pool.go b/p2p/peer_pool.go
index ae23cd15..a7bb78a1 100644
--- a/p2p/peer_pool.go
+++ b/p2p/peer_pool.go
@@ -121,14 +121,17 @@ func clean_up() {
 	peer_mutex.Lock()
 	defer peer_mutex.Unlock()
 	for k, v := range peer_map {
-		if v.FailCount >= 8 { // roughly 16 tries, 18 hrs before we discard the peer
+		if IsAddressConnected(ParseIPNoError(v.Address)) {
+			continue
+		}
+		if v.FailCount >= 8 { // roughly 8 tries before we discard the peer
 			delete(peer_map, k)
 		}
 		if v.LastConnected == 0 { // if never connected, purge the peer
 			delete(peer_map, k)
 		}

-		if uint64(time.Now().UTC().Unix()) > (v.LastConnected + 42000) { // purge all peers which were not connected in
+		if uint64(time.Now().UTC().Unix()) > (v.LastConnected + 3600) { // purge all peers which were not connected in
 			delete(peer_map, k)
 		}
 	}
@@ -139,7 +142,7 @@ func IsPeerInList(address string) bool {
 	peer_mutex.Lock()
 	defer peer_mutex.Unlock()

-	if _, ok := peer_map[address]; ok {
+	if _, ok := peer_map[ParseIPNoError(address)]; ok {
 		return true
 	}
 	return false
@@ -148,7 +151,7 @@ func GetPeerInList(address string) *Peer {
 	peer_mutex.Lock()
 	defer peer_mutex.Unlock()

-	if v, ok := peer_map[address]; ok {
+	if v, ok := peer_map[ParseIPNoError(address)]; ok {
 		return v
 	}
 	return nil
@@ -166,20 +169,20 @@ func Peer_Add(p *Peer) {
 	}

-	if v, ok := peer_map[p.Address]; ok {
+	if v, ok := peer_map[ParseIPNoError(p.Address)]; ok {
 		v.Lock()
 		//	logger.Infof("Peer already in list adding good count")
 		v.GoodCount++
 		v.Unlock()
 	} else {
 		//	logger.Infof("Peer adding to list")
-		peer_map[p.Address] = p
+		peer_map[ParseIPNoError(p.Address)] = p
 	}

 }

 // a peer marked as fail, will only be connected based on exponential back-off based on powers of 2
 func Peer_SetFail(address string) {
-	p := GetPeerInList(address)
+	p := GetPeerInList(ParseIPNoError(address))
 	if p == nil {
 		return
 	}
@@ -193,7 +196,7
@@ func Peer_SetFail(address string) {

 // we will only distribute peers which have been successfully connected by us
 func Peer_SetSuccess(address string) {
 	//logger.Infof("Setting peer as success")
-	p := GetPeerInList(address)
+	p := GetPeerInList(ParseIPNoError(address))
 	if p == nil {
 		return
 	}
@@ -233,7 +236,7 @@ func Peer_EnableBan(address string) (err error){

 func Peer_Delete(p *Peer) {
 	peer_mutex.Lock()
 	defer peer_mutex.Unlock()
-	delete(peer_map, p.Address)
+	delete(peer_map, ParseIPNoError(p.Address))
 }

 // prints all the connection info to screen
@@ -258,7 +261,7 @@ func PeerList_Print() {
 	for i := range list {
 		connected := ""
-		if IsAddressConnected(list[i].Address) {
+		if IsAddressConnected(ParseIPNoError(list[i].Address)) {
 			connected = "ACTIVE"
 		}
 		fmt.Printf("%-22s %-6s %4d %5d \n", list[i].Address, connected, list[i].GoodCount, list[i].FailCount)
@@ -269,7 +272,7 @@
 }

-// this function return peer count which have successful handshake
+// this function returns the peer count which are in our list
 func Peer_Counts() (Count uint64) {
 	peer_mutex.Lock()
 	defer peer_mutex.Unlock()
@@ -289,7 +292,7 @@ func find_peer_to_connect(version int) *Peer {
 	for _, v := range peer_map {
 		if uint64(time.Now().Unix()) > v.BlacklistBefore && // if ip is blacklisted skip it
 			uint64(time.Now().Unix()) > v.ConnectAfter &&
-			!IsAddressConnected(v.Address) && v.Whitelist && !IsAddressInBanList(v.Address) {
+			!IsAddressConnected(ParseIPNoError(v.Address)) && v.Whitelist && !IsAddressInBanList(ParseIPNoError(v.Address)) {
 			v.ConnectAfter = uint64(time.Now().UTC().Unix()) + 10 // minimum 10 secs gap
 			return v
 		}
@@ -298,7 +301,7 @@
 	for _, v := range peer_map {
 		if uint64(time.Now().Unix()) > v.BlacklistBefore && // if ip is blacklisted skip it
 			uint64(time.Now().Unix()) > v.ConnectAfter &&
-			!IsAddressConnected(v.Address) && !v.Whitelist && !IsAddressInBanList(v.Address) {
+			!IsAddressConnected(ParseIPNoError(v.Address)) && !v.Whitelist && !IsAddressInBanList(ParseIPNoError(v.Address)) {
 			v.ConnectAfter = uint64(time.Now().UTC().Unix()) + 10 // minimum 10 secs gap
 			return v
 		}
diff --git a/p2p/rpc_cbor_codec.go b/p2p/rpc_cbor_codec.go
index 3d38f00f..76f384ed 100644
--- a/p2p/rpc_cbor_codec.go
+++ b/p2p/rpc_cbor_codec.go
@@ -4,30 +4,30 @@ package p2p

 import "fmt"
 import "io"
 import "net"
-import "net/rpc"
-import "bufio"
+import "sync"
+import "time"
+import "github.com/cenkalti/rpc2"
 import "encoding/binary"
 import "github.com/fxamacker/cbor/v2"
 import "github.com/deroproject/derohe/config" // only used to get constants such as max data per frame

-// used to represent net/rpc structs
-type Request struct {
-	ServiceMethod string `cbor:"M"` // format: "Service.Method"
-	Seq           uint64 `cbor:"S"` // sequence number chosen by client
+// a single frame header type, used for both requests and responses
+type RequestResponse struct {
+	Method string `cbor:"M"` // format: "Service.Method"
+	Seq    uint64 `cbor:"S"` // echoes that of the request
+	Error  string `cbor:"E"` // error, if any.
 }

-type Response struct {
-	ServiceMethod string `cbor:"M"` // echoes that of the Request
-	Seq           uint64 `cbor:"S"` // echoes that of the request
-	Error         string `cbor:"E"` // error, if any.
-}

+const READ_TIMEOUT = 20 * time.Second
+const WRITE_TIMEOUT = 20 * time.Second

 // reads our data, length-prefixed blocks
-func Read_Data_Frame(r io.Reader, obj interface{}) error {
+func Read_Data_Frame(r net.Conn, obj interface{}) error {
 	var frame_length_buf [4]byte
 	//connection.set_timeout()
+	r.SetReadDeadline(time.Now().Add(READ_TIMEOUT))
 	nbyte, err := io.ReadFull(r, frame_length_buf[:])
 	if err != nil {
 		return err
 	}
@@ -58,7 +58,7 @@
 }

 // writes our data, length-prefixed blocks
-func Write_Data_Frame(w io.Writer, obj interface{}) error {
+func Write_Data_Frame(w net.Conn, obj interface{}) error {
 	var frame_length_buf [4]byte
 	data_bytes, err := cbor.Marshal(obj)
 	if err != nil {
@@ -66,6 +67,7 @@
 	}

 	binary.LittleEndian.PutUint32(frame_length_buf[:], uint32(len(data_bytes)))
+	w.SetWriteDeadline(time.Now().Add(WRITE_TIMEOUT))
 	if _, err = w.Write(frame_length_buf[:]); err != nil {
 		return err
 	}
@@ -76,112 +77,104 @@

 // ClientCodec implements the rpc.ClientCodec interface for generic golang objects.
 type ClientCodec struct {
-	r *bufio.Reader
-	w io.WriteCloser
+	r net.Conn
+	sync.Mutex
 }

-// ServerCodec implements the rpc.ServerCodec interface for generic protobufs.
-type ServerCodec ClientCodec
-
 // NewClientCodec returns a ClientCodec for communicating with the ServerCodec
 // on the other end of the conn.
-func NewCBORClientCodec(conn net.Conn) *ClientCodec {
-	return &ClientCodec{bufio.NewReader(conn), conn}
-}
-
-// NewServerCodec returns a ServerCodec that communicates with the ClientCodec
-// on the other end of the given conn.
-func NewCBORServerCodec(conn net.Conn) *ServerCodec {
-	return &ServerCodec{bufio.NewReader(conn), conn}
-}
-
-// WriteRequest writes the 4 byte length from the connection and encodes that many
-// subsequent bytes into the given object.
-func (c *ClientCodec) WriteRequest(req *rpc.Request, obj interface{}) error {
-	// Write the header
-	header := Request{ServiceMethod: req.ServiceMethod, Seq: req.Seq}
-	if err := Write_Data_Frame(c.w, header); err != nil {
-		return err
-	}
-	return Write_Data_Frame(c.w, obj)
+// to support deadlines we use net.Conn
+func NewCBORCodec(conn net.Conn) *ClientCodec {
+	return &ClientCodec{r: conn}
 }

 // ReadResponseHeader reads a 4 byte length from the connection, decodes that many
 // subsequent bytes, and stores the fields
 // in the given response.
-func (c *ClientCodec) ReadResponseHeader(resp *rpc.Response) error {
-	var header Response
+func (c *ClientCodec) ReadResponseHeader(resp *rpc2.Response) error {
+	var header RequestResponse
 	if err := Read_Data_Frame(c.r, &header); err != nil {
 		return err
 	}
-	if header.ServiceMethod == "" {
-		return fmt.Errorf("header missing method: %s", "no ServiceMethod")
-	}
-	resp.ServiceMethod = header.ServiceMethod
+	//if header.Method == "" {
+	//	return fmt.Errorf("header missing method: %s", "no Method")
+	//}
+	//resp.Method = header.Method
 	resp.Seq = header.Seq
 	resp.Error = header.Error
 	return nil
 }

-// ReadResponseBody reads a 4 byte length from the connection and decodes that many
-// subsequent bytes into the given object (which should be a pointer to a
-// struct).
-func (c *ClientCodec) ReadResponseBody(obj interface{}) error {
-	if obj == nil {
-		return nil
-	}
-	return Read_Data_Frame(c.r, obj)
-}
-
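On the wire, every frame is a 4-byte little-endian length followed by that many bytes of CBOR, and one RPC exchange is two frames: header, then body. The framing self-contained, as a sketch mirroring Write_Data_Frame above (illustrative names):

	// writeFrame length-prefixes one CBOR-encoded object onto the stream.
	func writeFrame(w io.Writer, obj interface{}) error {
		payload, err := cbor.Marshal(obj)
		if err != nil {
			return err
		}
		var n [4]byte
		binary.LittleEndian.PutUint32(n[:], uint32(len(payload)))
		if _, err = w.Write(n[:]); err != nil {
			return err
		}
		_, err = w.Write(payload) // body follows the 4-byte length prefix
		return err
	}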
 // Close closes the underlying connection.
 func (c *ClientCodec) Close() error {
-	return c.w.Close()
-}
-
-// Close closes the underlying connection.
-func (c *ServerCodec) Close() error {
-	return c.w.Close()
+	return c.r.Close()
 }

 // ReadHeader reads the header (which is prefixed by a 4 byte little-endian length
 // indicating its size) from the connection, decodes it, and stores the fields
 // in the given request or response.
-func (s *ServerCodec) ReadRequestHeader(req *rpc.Request) error {
-	var header Request
+func (s *ClientCodec) ReadHeader(req *rpc2.Request, resp *rpc2.Response) error {
+	var header RequestResponse
 	if err := Read_Data_Frame(s.r, &header); err != nil {
 		return err
 	}
-	if header.ServiceMethod == "" {
-		return fmt.Errorf("header missing method: %s", "empty ServiceMethod")
+
+	if header.Method != "" {
+		req.Seq = header.Seq
+		req.Method = header.Method
+	} else {
+		resp.Seq = header.Seq
+		resp.Error = header.Error
 	}
-	req.ServiceMethod = header.ServiceMethod
-	req.Seq = header.Seq
 	return nil
 }

 // ReadRequestBody reads a 4 byte length from the connection and decodes that many
 // subsequent bytes into the object
-func (s *ServerCodec) ReadRequestBody(obj interface{}) error {
+func (s *ClientCodec) ReadRequestBody(obj interface{}) error {
 	if obj == nil {
 		return nil
 	}
 	return Read_Data_Frame(s.r, obj)
 }

+// ReadResponseBody reads a 4 byte length from the connection and decodes that many
+// subsequent bytes into the given object (which should be a pointer to a
+// struct).
+func (c *ClientCodec) ReadResponseBody(obj interface{}) error {
+	if obj == nil {
+		return nil
+	}
+	return Read_Data_Frame(c.r, obj)
+}
+
+// WriteRequest writes a 4 byte length prefix followed by the CBOR-encoded
+// header and body frames.
+func (c *ClientCodec) WriteRequest(req *rpc2.Request, obj interface{}) error {
+	c.Lock()
+	defer c.Unlock()
+
+	header := RequestResponse{Method: req.Method, Seq: req.Seq}
+	if err := Write_Data_Frame(c.r, header); err != nil {
+		return err
+	}
+	return Write_Data_Frame(c.r, obj)
+}
+
 // WriteResponse writes the appropriate header. If
 // the response was invalid, the size of the body of the resp is reported as
 // having size zero and is not sent.
-func (s *ServerCodec) WriteResponse(resp *rpc.Response, obj interface{}) error {
-	// Write the header
-	header := Response{ServiceMethod: resp.ServiceMethod, Seq: resp.Seq, Error: resp.Error}
-
-	if err := Write_Data_Frame(s.w, header); err != nil {
+func (c *ClientCodec) WriteResponse(resp *rpc2.Response, obj interface{}) error {
+	c.Lock()
+	defer c.Unlock()
+	header := RequestResponse{Seq: resp.Seq, Error: resp.Error}
+	if err := Write_Data_Frame(c.r, header); err != nil {
 		return err
 	}

 	if resp.Error == "" { // only write response object if error is nil
-		return Write_Data_Frame(s.w, obj)
+		return Write_Data_Frame(c.r, obj)
 	}

 	return nil

diff --git a/p2p/rpc_handshake.go b/p2p/rpc_handshake.go
index 66191414..cb1dd7fa 100644
--- a/p2p/rpc_handshake.go
+++ b/p2p/rpc_handshake.go
@@ -17,13 +17,13 @@ package p2p

 import "fmt"
+import "net"
 import "bytes"
+import "context"
 import "sync/atomic"
 import "time"

-import "github.com/paulbellamy/ratecounter"
-
 import "github.com/deroproject/derohe/config"
 import "github.com/deroproject/derohe/globals"

@@ -56,7 +58,10 @@ func (connection *Connection) dispatch_test_handshake() {
 	var request, response Handshake_Struct
 	request.Fill()

-	if err := connection.RConn.Client.Call("Peer.Handshake", request, &response); err != nil {
+	ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
+	defer cancel() // release the timer; discarding the cancel func leaks it
+	if err := connection.Client.CallWithContext(ctx, "Peer.Handshake", request, &response); err != nil {
+		connection.logger.V(2).Error(err, "cannot handshake")
 		connection.exit()
 		return
 	}
@@ -66,15 +68,10 @@
 		connection.exit()
 		return
 	}
-
-	connection.request_time.Store(time.Now())
-	connection.SpeedIn = ratecounter.NewRateCounter(60 * time.Second)
-	connection.SpeedOut = ratecounter.NewRateCounter(60 * time.Second)
-
 	connection.update(&response.Common) // update common information

-	if !connection.Incoming { // setup success
-		Peer_SetSuccess(connection.Addr.String())
+	if !Connection_Add(connection) { // add connection to pool
+		connection.exit()
+		return
 	}

 	if len(response.ProtocolVersion) < 128 {
@@ -98,24 +95,15 @@
 	if connection.Port != 0 && connection.Port <= 65535 { // peer is saying it has an open port, handshake is success so add peer
 		var p Peer
-		if connection.Addr.IP.To4() != nil { // if ipv4
-			p.Address = fmt.Sprintf("%s:%d", connection.Addr.IP.String(), connection.Port)
+		if net.ParseIP(Address(connection)).To4() != nil { // if ipv4
+			p.Address = fmt.Sprintf("%s:%d", Address(connection), connection.Port)
 		} else { // if ipv6
-			p.Address = fmt.Sprintf("[%s]:%d", connection.Addr.IP.String(), connection.Port)
+			p.Address = fmt.Sprintf("[%s]:%d", Address(connection), connection.Port)
 		}
 		p.ID = connection.Peer_ID
 		p.LastConnected = uint64(time.Now().UTC().Unix())

-		// TODO we should add any flags here if necessary, but they are not
-		// required, since a peer can only be used if connected and if connected
-		// we already have a truly synced view
-		for _, k := range response.Flags {
-			switch k {
-			//case FLAG_MINER:p.Miner = true
-			default:
-			}
-		}
 		Peer_Add(&p)
 	}
@@ -127,25 +115,7 @@
 	}
-	Connection_Add(connection) // add connection to pool
-
-	// mark active
-	var r Dummy
-	fill_common(&r.Common) // fill common info
-	if err := connection.RConn.Client.Call("Peer.Active", r, &r); err != nil {
-		connection.exit()
-		return
-	}
-
-}
-
-// mark connection active
-func (c *Connection) Active(req Dummy, dummy *Dummy) error {
-	defer handle_connection_panic(c)
-	c.update(&req.Common) // update common information
-	atomic.StoreUint32(&c.State, ACTIVE)
-	fill_common(&dummy.Common) // fill common info
-	return nil
+	atomic.StoreUint32(&connection.State, ACTIVE)
 }

 // used to ping pong

diff --git a/p2p/rpc_notifications.go b/p2p/rpc_notifications.go
index 947ed7d4..c1a7dab4 100644
--- a/p2p/rpc_notifications.go
+++ b/p2p/rpc_notifications.go
@@ -91,7 +91,7 @@ func (c *Connection) NotifyINV(request ObjectList, response *Dummy) (err error)
 	if dirty { // request inventory only if we want it
 		var oresponse Objects
 		fill_common(&need.Common) // fill common info
-		if err = c.RConn.Client.Call("Peer.GetObject", need, &oresponse); err != nil {
+		if err = c.Client.Call("Peer.GetObject", need, &oresponse); err != nil {
 			c.logger.V(2).Error(err, "Call failed GetObject", "need_objects", need)
 			c.exit()
 			return
@@ -113,8 +113,8 @@ func (c *Connection) NotifyINV(request ObjectList, response *Dummy) (err error)
 func (c *Connection) NotifyMiniBlock(request Objects, response *Dummy) (err error) {
 	defer handle_connection_panic(c)
-	if len(request.MiniBlocks) != 1 {
-		err = fmt.Errorf("Notify Block can notify single block")
+	if len(request.MiniBlocks) >= 5 {
+		err = fmt.Errorf("Notify Block can notify max 4 miniblocks")
 		c.logger.V(3).Error(err, "Should be banned")
 		c.exit()
 		return err
@@ -122,20 +122,20 @@ func (c *Connection) NotifyMiniBlock(request Objects, response *Dummy) (err erro
 	fill_common_T1(&request.Common)
 	c.update(&request.Common) // update common information

-	var mbl_arrays [][]byte
-	if len(c.previous_mbl) > 0 {
-		mbl_arrays = append(mbl_arrays, c.previous_mbl)
-	}
-	mbl_arrays = append(mbl_arrays, request.MiniBlocks...)
+	var mbls []block.MiniBlock

-	for i := range mbl_arrays {
+	for i := range request.MiniBlocks {
 		var mbl block.MiniBlock
-		var ok bool
-
-		if err = mbl.Deserialize(mbl_arrays[i]); err != nil {
+		if err = mbl.Deserialize(request.MiniBlocks[i]); err != nil {
 			return err
 		}
+		mbls = append(mbls, mbl)
+	}
+
+	var valid_found bool
+
+	for _, mbl := range mbls {
+		var ok bool
 		if mbl.Timestamp > uint64(globals.Time().UTC().UnixMilli())+50 { // 50 ms passing allowed
 			return errormsg.ErrInvalidTimestamp
 		}
@@ -153,8 +153,7 @@ func (c *Connection) NotifyMiniBlock(request Objects, response *Dummy) (err erro
 		// first check whether the incoming miniblock can be added to sub chains
 		if !chain.MiniBlocks.IsConnected(mbl) {
-			c.previous_mbl = mbl.Serialize()
-			c.logger.V(3).Error(err, "Disconnected miniblock","mbl",mbl.String())
+			c.logger.V(3).Error(err, "Disconnected miniblock", "mbl", mbl.String())
 			//return fmt.Errorf("Disconnected miniblock")
 			continue
 		}
@@ -213,9 +212,13 @@ func (c *Connection) NotifyMiniBlock(request Objects, response *Dummy) (err erro
 		if err, ok = chain.InsertMiniBlock(mbl); !ok {
 			return err
 		} else { // rebroadcast miniblock
-			defer broadcast_MiniBlock(mbl, c.Peer_ID, request.Sent) // do not send back to the original peer
+			valid_found = true
 		}
 	}
+	if valid_found {
+		Peer_SetSuccess(c.Addr.String())
+		broadcast_MiniBlock(mbls, c.Peer_ID, request.Sent) // do not send back to the original peer
+	}
 	fill_common(&response.Common)                         // fill common info
 	fill_common_T0T1T2(&request.Common, &response.Common) // fill time related information
 	return nil
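The timecheck change below gates `timeinsync` on the measured NTP clock offset. A minimal standalone sketch of that offset check, assuming the `ntp` import in timecheck.go is `github.com/beevik/ntp` (its `Query`, `Validate`, and `ClockOffset` match the calls used here); the server name is a placeholder:

```go
package main

import (
	"fmt"
	"time"

	"github.com/beevik/ntp" // assumed NTP client package
)

func main() {
	response, err := ntp.Query("pool.ntp.org") // placeholder server
	if err != nil {
		fmt.Println("query failed:", err)
		return
	}
	if err := response.Validate(); err != nil {
		fmt.Println("invalid response:", err)
		return
	}
	// in sync only while the offset stays inside (-1s, +1s)
	offset := response.ClockOffset
	timeinsync := offset > -time.Second && offset < time.Second
	fmt.Printf("offset=%v in_sync=%v\n", offset, timeinsync)
}
```

diff --git a/p2p/timecheck.go b/p2p/timecheck.go
index b0ac2a44..e0b1c819 100644
--- a/p2p/timecheck.go
+++ b/p2p/timecheck.go
@@ -55,7 +55,7 @@ func time_check_routine() {
 	server := timeservers[random.Int()%len(timeservers)]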
 	if response, err := ntp.Query(server); err != nil {
-		logger.V(2).Error(err, "error while querying time", "server", server)
+		//logger.V(2).Error(err, "error while querying time", "server", server)
 	} else if response.Validate() == nil {
 		if response.ClockOffset.Seconds() > -.05 && response.ClockOffset.Seconds() < .05 {
@@ -86,6 +86,7 @@ func time_check_routine() {
 		if response.ClockOffset.Seconds() > -1.0 && response.ClockOffset.Seconds() < 1.0 { // chrony can maintain up to 5 ms, ntps can maintain up to 10
 			timeinsync = true
 		} else {
+			timeinsync = false
 			logger.V(1).Error(nil, "Your system time deviation is more than 1 sec (%s)."+
 				"\nYou may experience chain sync issues and/or other side-effects."+
 				"\nIf you are mining, your blocks may get rejected."+

diff --git a/vendor/github.com/cenkalti/hub/.gitignore b/vendor/github.com/cenkalti/hub/.gitignore
new file mode 100644
index 00000000..00268614
--- /dev/null
+++ b/vendor/github.com/cenkalti/hub/.gitignore
@@ -0,0 +1,22 @@
+# Compiled Object files, Static and Dynamic libs (Shared Objects)
+*.o
+*.a
+*.so
+
+# Folders
+_obj
+_test
+
+# Architecture specific extensions/prefixes
+*.[568vq]
+[568vq].out
+
+*.cgo1.go
+*.cgo2.c
+_cgo_defun.c
+_cgo_gotypes.go
+_cgo_export.*
+
+_testmain.go
+
+*.exe

diff --git a/vendor/github.com/cenkalti/hub/.travis.yml b/vendor/github.com/cenkalti/hub/.travis.yml
new file mode 100644
index 00000000..d8cecb0d
--- /dev/null
+++ b/vendor/github.com/cenkalti/hub/.travis.yml
@@ -0,0 +1,5 @@
+language: go
+go: 1.13
+arch:
+  - amd64
+  - ppc64le

diff --git a/vendor/github.com/cenkalti/hub/LICENSE b/vendor/github.com/cenkalti/hub/LICENSE
new file mode 100644
index 00000000..89b81799
--- /dev/null
+++ b/vendor/github.com/cenkalti/hub/LICENSE
@@ -0,0 +1,20 @@
+The MIT License (MIT)
+
+Copyright (c) 2014 Cenk Altı
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/vendor/github.com/cenkalti/hub/README.md b/vendor/github.com/cenkalti/hub/README.md new file mode 100644 index 00000000..d3f21181 --- /dev/null +++ b/vendor/github.com/cenkalti/hub/README.md @@ -0,0 +1,5 @@ +hub +=== + +[![GoDoc](https://godoc.org/github.com/cenkalti/hub?status.png)](https://godoc.org/github.com/cenkalti/hub) +[![Build Status](https://travis-ci.org/cenkalti/hub.png)](https://travis-ci.org/cenkalti/hub) diff --git a/vendor/github.com/cenkalti/hub/example_test.go b/vendor/github.com/cenkalti/hub/example_test.go new file mode 100644 index 00000000..ea6fb5f1 --- /dev/null +++ b/vendor/github.com/cenkalti/hub/example_test.go @@ -0,0 +1,32 @@ +package hub_test + +import ( + "fmt" + + "github.com/cenk/hub" +) + +// Different event kinds +const ( + happenedA hub.Kind = iota + happenedB + happenedC +) + +// Our custom event type +type EventA struct { + arg1, arg2 int +} + +// Implement hub.Event interface +func (e EventA) Kind() hub.Kind { return happenedA } + +func Example() { + hub.Subscribe(happenedA, func(e hub.Event) { + a := e.(EventA) // Cast to concrete type + fmt.Println(a.arg1 + a.arg2) + }) + + hub.Publish(EventA{2, 3}) + // Output: 5 +} diff --git a/vendor/github.com/cenkalti/hub/hub.go b/vendor/github.com/cenkalti/hub/hub.go new file mode 100644 index 00000000..24c5efa8 --- /dev/null +++ b/vendor/github.com/cenkalti/hub/hub.go @@ -0,0 +1,82 @@ +// Package hub provides a simple event dispatcher for publish/subscribe pattern. +package hub + +import "sync" + +type Kind int + +// Event is an interface for published events. +type Event interface { + Kind() Kind +} + +// Hub is an event dispatcher, publishes events to the subscribers +// which are subscribed for a specific event type. +// Optimized for publish calls. +// The handlers may be called in order different than they are registered. +type Hub struct { + subscribers map[Kind][]handler + m sync.RWMutex + seq uint64 +} + +type handler struct { + f func(Event) + id uint64 +} + +// Subscribe registers f for the event of a specific kind. +func (h *Hub) Subscribe(kind Kind, f func(Event)) (cancel func()) { + var cancelled bool + h.m.Lock() + h.seq++ + id := h.seq + if h.subscribers == nil { + h.subscribers = make(map[Kind][]handler) + } + h.subscribers[kind] = append(h.subscribers[kind], handler{id: id, f: f}) + h.m.Unlock() + return func() { + h.m.Lock() + if cancelled { + h.m.Unlock() + return + } + cancelled = true + a := h.subscribers[kind] + for i, f := range a { + if f.id == id { + a[i], h.subscribers[kind] = a[len(a)-1], a[:len(a)-1] + break + } + } + if len(a) == 0 { + delete(h.subscribers, kind) + } + h.m.Unlock() + } +} + +// Publish an event to the subscribers. +func (h *Hub) Publish(e Event) { + h.m.RLock() + if handlers, ok := h.subscribers[e.Kind()]; ok { + for _, h := range handlers { + h.f(e) + } + } + h.m.RUnlock() +} + +// DefaultHub is the default Hub used by Publish and Subscribe. +var DefaultHub Hub + +// Subscribe registers f for the event of a specific kind in the DefaultHub. +func Subscribe(kind Kind, f func(Event)) (cancel func()) { + return DefaultHub.Subscribe(kind, f) +} + +// Publish an event to the subscribers in DefaultHub. 
+func Publish(e Event) { + DefaultHub.Publish(e) +} diff --git a/vendor/github.com/cenkalti/hub/hub_test.go b/vendor/github.com/cenkalti/hub/hub_test.go new file mode 100644 index 00000000..906482bb --- /dev/null +++ b/vendor/github.com/cenkalti/hub/hub_test.go @@ -0,0 +1,40 @@ +package hub + +import "testing" + +const testKind Kind = 1 +const testValue = "foo" + +type testEvent string + +func (e testEvent) Kind() Kind { + return testKind +} + +func TestPubSub(t *testing.T) { + var h Hub + var s string + + h.Subscribe(testKind, func(e Event) { s = string(e.(testEvent)) }) + h.Publish(testEvent(testValue)) + + if s != testValue { + t.Errorf("invalid value: %s", s) + } +} + +func TestCancel(t *testing.T) { + var h Hub + var called int + var f = func(e Event) { called += 1 } + + _ = h.Subscribe(testKind, f) + cancel := h.Subscribe(testKind, f) + h.Publish(testEvent(testValue)) // 2 calls to f + cancel() + h.Publish(testEvent(testValue)) // 1 call to f + + if called != 3 { + t.Errorf("unexpected call count: %d", called) + } +} diff --git a/vendor/github.com/cenkalti/rpc2/.gitignore b/vendor/github.com/cenkalti/rpc2/.gitignore new file mode 100644 index 00000000..83656241 --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/.gitignore @@ -0,0 +1,23 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test diff --git a/vendor/github.com/cenkalti/rpc2/.travis.yml b/vendor/github.com/cenkalti/rpc2/.travis.yml new file mode 100644 index 00000000..ae8233c2 --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - 1.15 + - tip + +arch: + - amd64 + - ppc64le diff --git a/vendor/github.com/cenkalti/rpc2/LICENSE b/vendor/github.com/cenkalti/rpc2/LICENSE new file mode 100644 index 00000000..d565b1b1 --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014 Cenk Altı + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/vendor/github.com/cenkalti/rpc2/README.md b/vendor/github.com/cenkalti/rpc2/README.md new file mode 100644 index 00000000..3dffd26e --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/README.md @@ -0,0 +1,82 @@ +rpc2 +==== + +[![GoDoc](https://godoc.org/github.com/cenkalti/rpc2?status.png)](https://godoc.org/github.com/cenkalti/rpc2) +[![Build Status](https://travis-ci.org/cenkalti/rpc2.png)](https://travis-ci.org/cenkalti/rpc2) + +rpc2 is a fork of net/rpc package in the standard library. +The main goal is to add bi-directional support to calls. +That means server can call the methods of client. +This is not possible with net/rpc package. +In order to do this it adds a `*Client` argument to method signatures. + +Install +-------- + + go get github.com/cenkalti/rpc2 + +Example server +--------------- + +```go +package main + +import ( + "fmt" + "net" + + "github.com/cenkalti/rpc2" +) + +type Args struct{ A, B int } +type Reply int + +func main() { + srv := rpc2.NewServer() + srv.Handle("add", func(client *rpc2.Client, args *Args, reply *Reply) error { + + // Reversed call (server to client) + var rep Reply + client.Call("mult", Args{2, 3}, &rep) + fmt.Println("mult result:", rep) + + *reply = Reply(args.A + args.B) + return nil + }) + + lis, _ := net.Listen("tcp", "127.0.0.1:5000") + srv.Accept(lis) +} +``` + +Example Client +--------------- + +```go +package main + +import ( + "fmt" + "net" + + "github.com/cenkalti/rpc2" +) + +type Args struct{ A, B int } +type Reply int + +func main() { + conn, _ := net.Dial("tcp", "127.0.0.1:5000") + + clt := rpc2.NewClient(conn) + clt.Handle("mult", func(client *rpc2.Client, args *Args, reply *Reply) error { + *reply = Reply(args.A * args.B) + return nil + }) + go clt.Run() + + var rep Reply + clt.Call("add", Args{1, 2}, &rep) + fmt.Println("add result:", rep) +} +``` diff --git a/vendor/github.com/cenkalti/rpc2/client.go b/vendor/github.com/cenkalti/rpc2/client.go new file mode 100644 index 00000000..cc995697 --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/client.go @@ -0,0 +1,364 @@ +// Package rpc2 provides bi-directional RPC client and server similar to net/rpc. +package rpc2 + +import ( + "context" + "errors" + "io" + "log" + "reflect" + "sync" +) + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + mutex sync.Mutex // protects pending, seq, request + sending sync.Mutex + request Request // temp area used in send() + seq uint64 + pending map[uint64]*Call + closing bool + shutdown bool + server bool + codec Codec + handlers map[string]*handler + disconnect chan struct{} + State *State // additional information to associate with client + blocking bool // whether to block request handling +} + +// NewClient returns a new Client to handle requests to the +// set of services at the other end of the connection. +// It adds a buffer to the write side of the connection so +// the header and payload are sent as a unit. +func NewClient(conn io.ReadWriteCloser) *Client { + return NewClientWithCodec(NewGobCodec(conn)) +} + +// NewClientWithCodec is like NewClient but uses the specified +// codec to encode requests and decode responses. +func NewClientWithCodec(codec Codec) *Client { + return &Client{ + codec: codec, + pending: make(map[uint64]*Call), + handlers: make(map[string]*handler), + disconnect: make(chan struct{}), + seq: 1, // 0 means notification. 
+	}
+}
+
+// SetBlocking puts the client in blocking mode.
+// In blocking mode, received requests are processed synchronously.
+// If you have methods that may take a long time, other subsequent requests may time out.
+func (c *Client) SetBlocking(blocking bool) {
+	c.blocking = blocking
+}
+
+// Run the client's read loop.
+// You must run this method before calling any methods on the server.
+func (c *Client) Run() {
+	c.readLoop()
+}
+
+// DisconnectNotify returns a channel that is closed
+// when the client connection has gone away.
+func (c *Client) DisconnectNotify() chan struct{} {
+	return c.disconnect
+}
+
+// Handle registers the handler function for the given method. If a handler already exists for method, Handle panics.
+func (c *Client) Handle(method string, handlerFunc interface{}) {
+	addHandler(c.handlers, method, handlerFunc)
+}
+
+// readLoop reads messages from codec.
+// It reads a request or a response to the previous request.
+// If the message is request, calls the handler function.
+// If the message is response, sends the reply to the associated call.
+func (c *Client) readLoop() {
+	var err error
+	var req Request
+	var resp Response
+	for err == nil {
+		req = Request{}
+		resp = Response{}
+		if err = c.codec.ReadHeader(&req, &resp); err != nil {
+			break
+		}
+
+		if req.Method != "" {
+			// request comes to server
+			if err = c.readRequest(&req); err != nil {
+				debugln("rpc2: error reading request:", err.Error())
+			}
+		} else {
+			// response comes to client
+			if err = c.readResponse(&resp); err != nil {
+				debugln("rpc2: error reading response:", err.Error())
+			}
+		}
+	}
+	// Terminate pending calls.
+	c.sending.Lock()
+	c.mutex.Lock()
+	c.shutdown = true
+	closing := c.closing
+	if err == io.EOF {
+		if closing {
+			err = ErrShutdown
+		} else {
+			err = io.ErrUnexpectedEOF
+		}
+	}
+	for _, call := range c.pending {
+		call.Error = err
+		call.done()
+	}
+	c.mutex.Unlock()
+	c.sending.Unlock()
+	if err != io.EOF && !closing && !c.server {
+		debugln("rpc2: client protocol error:", err)
+	}
+	close(c.disconnect)
+	if !closing {
+		c.codec.Close()
+	}
+}
+
+func (c *Client) handleRequest(req Request, method *handler, argv reflect.Value) {
+	// Invoke the method, providing a new value for the reply.
+	replyv := reflect.New(method.replyType.Elem())
+
+	returnValues := method.fn.Call([]reflect.Value{reflect.ValueOf(c), argv, replyv})
+
+	// Do not send response if request is a notification.
+	if req.Seq == 0 {
+		return
+	}
+
+	// The return value for the method is an error.
+	errInter := returnValues[0].Interface()
+	errmsg := ""
+	if errInter != nil {
+		errmsg = errInter.(error).Error()
+	}
+	resp := &Response{
+		Seq:   req.Seq,
+		Error: errmsg,
+	}
+	if err := c.codec.WriteResponse(resp, replyv.Interface()); err != nil {
+		debugln("rpc2: error writing response:", err.Error())
+	}
+}
+
+func (c *Client) readRequest(req *Request) error {
+	method, ok := c.handlers[req.Method]
+	if !ok {
+		resp := &Response{
+			Seq:   req.Seq,
+			Error: "rpc2: can't find method " + req.Method,
+		}
+		return c.codec.WriteResponse(resp, resp)
+	}
+
+	// Decode the argument value.
+	var argv reflect.Value
+	argIsValue := false // if true, need to indirect before calling.
+	if method.argType.Kind() == reflect.Ptr {
+		argv = reflect.New(method.argType.Elem())
+	} else {
+		argv = reflect.New(method.argType)
+		argIsValue = true
+	}
+	// argv guaranteed to be a pointer now.
+	if err := c.codec.ReadRequestBody(argv.Interface()); err != nil {
+		return err
+	}
+	if argIsValue {
+		argv = argv.Elem()
+	}
+
+	if c.blocking {
+		c.handleRequest(*req, method, argv)
+	} else {
+		go c.handleRequest(*req, method, argv)
+	}
+
+	return nil
+}
+
+func (c *Client) readResponse(resp *Response) error {
+	seq := resp.Seq
+	c.mutex.Lock()
+	call := c.pending[seq]
+	delete(c.pending, seq)
+	c.mutex.Unlock()
+
+	var err error
+	switch {
+	case call == nil:
+		// We've got no pending call. That usually means that
+		// WriteRequest partially failed, and call was already
+		// removed; response is a server telling us about an
+		// error reading request body. We should still attempt
+		// to read error body, but there's no one to give it to.
+		err = c.codec.ReadResponseBody(nil)
+		if err != nil {
+			err = errors.New("reading error body: " + err.Error())
+		}
+	case resp.Error != "":
+		// We've got an error response. Give this to the request;
+		// any subsequent requests will get the ReadResponseBody
+		// error if there is one.
+		call.Error = ServerError(resp.Error)
+		err = c.codec.ReadResponseBody(nil)
+		if err != nil {
+			err = errors.New("reading error body: " + err.Error())
+		}
+		call.done()
+	default:
+		err = c.codec.ReadResponseBody(call.Reply)
+		if err != nil {
+			call.Error = errors.New("reading body " + err.Error())
+		}
+		call.done()
+	}
+
+	return err
+}
+
+// Close waits for active calls to finish and closes the codec.
+func (c *Client) Close() error {
+	c.mutex.Lock()
+	if c.shutdown || c.closing {
+		c.mutex.Unlock()
+		return ErrShutdown
+	}
+	c.closing = true
+	c.mutex.Unlock()
+	return c.codec.Close()
+}
+
+// Go invokes the function asynchronously. It returns the Call structure representing
+// the invocation. The done channel will signal when the call is complete by returning
+// the same Call object. If done is nil, Go will allocate a new channel.
+// If non-nil, done must be buffered or Go will deliberately crash.
+func (c *Client) Go(method string, args interface{}, reply interface{}, done chan *Call) *Call {
+	call := new(Call)
+	call.Method = method
+	call.Args = args
+	call.Reply = reply
+	if done == nil {
+		done = make(chan *Call, 10) // buffered.
+	} else {
+		// If caller passes done != nil, it must arrange that
+		// done has enough buffer for the number of simultaneous
+		// RPCs that will be using that channel. If the channel
+		// is totally unbuffered, it's best not to run at all.
+		if cap(done) == 0 {
+			log.Panic("rpc2: done channel is unbuffered")
+		}
+	}
+	call.Done = done
+	c.send(call)
+	return call
+}
+
+// CallWithContext invokes the named function, waits for it to complete, and
+// returns its error status, or an error from Context timeout.
+func (c *Client) CallWithContext(ctx context.Context, method string, args interface{}, reply interface{}) error {
+	call := c.Go(method, args, reply, make(chan *Call, 1))
+	select {
+	case <-call.Done:
+		return call.Error
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+// Call invokes the named function, waits for it to complete, and returns its error status.
+func (c *Client) Call(method string, args interface{}, reply interface{}) error {
+	return c.CallWithContext(context.Background(), method, args, reply)
+}
+
+func (call *Call) done() {
+	select {
+	case call.Done <- call:
+		// ok
+	default:
+		// We don't want to block here. It is the caller's responsibility to make
+		// sure the channel has enough buffer space. See comment in Go().
+ debugln("rpc2: discarding Call reply due to insufficient Done chan capacity") + } +} + +// ServerError represents an error that has been returned from +// the remote side of the RPC connection. +type ServerError string + +func (e ServerError) Error() string { + return string(e) +} + +// ErrShutdown is returned when the connection is closing or closed. +var ErrShutdown = errors.New("connection is shut down") + +// Call represents an active RPC. +type Call struct { + Method string // The name of the service and method to call. + Args interface{} // The argument to the function (*struct). + Reply interface{} // The reply from the function (*struct). + Error error // After completion, the error status. + Done chan *Call // Strobes when call is complete. +} + +func (c *Client) send(call *Call) { + c.sending.Lock() + defer c.sending.Unlock() + + // Register this call. + c.mutex.Lock() + if c.shutdown || c.closing { + call.Error = ErrShutdown + c.mutex.Unlock() + call.done() + return + } + seq := c.seq + c.seq++ + c.pending[seq] = call + c.mutex.Unlock() + + // Encode and send the request. + c.request.Seq = seq + c.request.Method = call.Method + err := c.codec.WriteRequest(&c.request, call.Args) + if err != nil { + c.mutex.Lock() + call = c.pending[seq] + delete(c.pending, seq) + c.mutex.Unlock() + if call != nil { + call.Error = err + call.done() + } + } +} + +// Notify sends a request to the receiver but does not wait for a return value. +func (c *Client) Notify(method string, args interface{}) error { + c.sending.Lock() + defer c.sending.Unlock() + + if c.shutdown || c.closing { + return ErrShutdown + } + + c.request.Seq = 0 + c.request.Method = method + return c.codec.WriteRequest(&c.request, args) +} diff --git a/vendor/github.com/cenkalti/rpc2/codec.go b/vendor/github.com/cenkalti/rpc2/codec.go new file mode 100644 index 00000000..b097d9aa --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/codec.go @@ -0,0 +1,125 @@ +package rpc2 + +import ( + "bufio" + "encoding/gob" + "io" + "sync" +) + +// A Codec implements reading and writing of RPC requests and responses. +// The client calls ReadHeader to read a message header. +// The implementation must populate either Request or Response argument. +// Depending on which argument is populated, ReadRequestBody or +// ReadResponseBody is called right after ReadHeader. +// ReadRequestBody and ReadResponseBody may be called with a nil +// argument to force the body to be read and then discarded. +type Codec interface { + // ReadHeader must read a message and populate either the request + // or the response by inspecting the incoming message. + ReadHeader(*Request, *Response) error + + // ReadRequestBody into args argument of handler function. + ReadRequestBody(interface{}) error + + // ReadResponseBody into reply argument of handler function. + ReadResponseBody(interface{}) error + + // WriteRequest must be safe for concurrent use by multiple goroutines. + WriteRequest(*Request, interface{}) error + + // WriteResponse must be safe for concurrent use by multiple goroutines. + WriteResponse(*Response, interface{}) error + + // Close is called when client/server finished with the connection. + Close() error +} + +// Request is a header written before every RPC call. +type Request struct { + Seq uint64 // sequence number chosen by client + Method string +} + +// Response is a header written before every RPC return. +type Response struct { + Seq uint64 // echoes that of the request + Error string // error, if any. 
+} + +type gobCodec struct { + rwc io.ReadWriteCloser + dec *gob.Decoder + enc *gob.Encoder + encBuf *bufio.Writer + mutex sync.Mutex +} + +type message struct { + Seq uint64 + Method string + Error string +} + +// NewGobCodec returns a new rpc2.Codec using gob encoding/decoding on conn. +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &gobCodec{ + rwc: conn, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + encBuf: buf, + } +} + +func (c *gobCodec) ReadHeader(req *Request, resp *Response) error { + var msg message + if err := c.dec.Decode(&msg); err != nil { + return err + } + + if msg.Method != "" { + req.Seq = msg.Seq + req.Method = msg.Method + } else { + resp.Seq = msg.Seq + resp.Error = msg.Error + } + return nil +} + +func (c *gobCodec) ReadRequestBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *gobCodec) ReadResponseBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *gobCodec) WriteRequest(r *Request, body interface{}) (err error) { + c.mutex.Lock() + defer c.mutex.Unlock() + if err = c.enc.Encode(r); err != nil { + return + } + if err = c.enc.Encode(body); err != nil { + return + } + return c.encBuf.Flush() +} + +func (c *gobCodec) WriteResponse(r *Response, body interface{}) (err error) { + c.mutex.Lock() + defer c.mutex.Unlock() + if err = c.enc.Encode(r); err != nil { + return + } + if err = c.enc.Encode(body); err != nil { + return + } + return c.encBuf.Flush() +} + +func (c *gobCodec) Close() error { + return c.rwc.Close() +} diff --git a/vendor/github.com/cenkalti/rpc2/debug.go b/vendor/github.com/cenkalti/rpc2/debug.go new file mode 100644 index 00000000..ec1b6252 --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/debug.go @@ -0,0 +1,12 @@ +package rpc2 + +import "log" + +// DebugLog controls the printing of internal and I/O errors. +var DebugLog = false + +func debugln(v ...interface{}) { + if DebugLog { + log.Println(v...) + } +} diff --git a/vendor/github.com/cenkalti/rpc2/jsonrpc/jsonrpc.go b/vendor/github.com/cenkalti/rpc2/jsonrpc/jsonrpc.go new file mode 100644 index 00000000..87e11688 --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/jsonrpc/jsonrpc.go @@ -0,0 +1,226 @@ +// Package jsonrpc implements a JSON-RPC ClientCodec and ServerCodec for the rpc2 package. +// +// Beside struct types, JSONCodec allows using positional arguments. +// Use []interface{} as the type of argument when sending and receiving methods. +// +// Positional arguments example: +// server.Handle("add", func(client *rpc2.Client, args []interface{}, result *float64) error { +// *result = args[0].(float64) + args[1].(float64) +// return nil +// }) +// +// var result float64 +// client.Call("add", []interface{}{1, 2}, &result) +// +package jsonrpc + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "reflect" + "sync" + + "github.com/cenkalti/rpc2" +) + +type jsonCodec struct { + dec *json.Decoder // for reading JSON values + enc *json.Encoder // for writing JSON values + c io.Closer + + // temporary work space + msg message + serverRequest serverRequest + clientResponse clientResponse + + // JSON-RPC clients can use arbitrary json values as request IDs. + // Package rpc expects uint64 request IDs. + // We assign uint64 sequence numbers to incoming requests + // but save the original request ID in the pending map. + // When rpc responds, we use the sequence number in + // the response to find the original request ID. 
+ mutex sync.Mutex // protects seq, pending + pending map[uint64]*json.RawMessage + seq uint64 +} + +// NewJSONCodec returns a new rpc2.Codec using JSON-RPC on conn. +func NewJSONCodec(conn io.ReadWriteCloser) rpc2.Codec { + return &jsonCodec{ + dec: json.NewDecoder(conn), + enc: json.NewEncoder(conn), + c: conn, + pending: make(map[uint64]*json.RawMessage), + } +} + +// serverRequest and clientResponse combined +type message struct { + Method string `json:"method"` + Params *json.RawMessage `json:"params"` + Id *json.RawMessage `json:"id"` + Result *json.RawMessage `json:"result"` + Error interface{} `json:"error"` +} + +// Unmarshal to +type serverRequest struct { + Method string `json:"method"` + Params *json.RawMessage `json:"params"` + Id *json.RawMessage `json:"id"` +} +type clientResponse struct { + Id uint64 `json:"id"` + Result *json.RawMessage `json:"result"` + Error interface{} `json:"error"` +} + +// to Marshal +type serverResponse struct { + Id *json.RawMessage `json:"id"` + Result interface{} `json:"result"` + Error interface{} `json:"error"` +} +type clientRequest struct { + Method string `json:"method"` + Params interface{} `json:"params"` + Id *uint64 `json:"id"` +} + +func (c *jsonCodec) ReadHeader(req *rpc2.Request, resp *rpc2.Response) error { + c.msg = message{} + if err := c.dec.Decode(&c.msg); err != nil { + return err + } + + if c.msg.Method != "" { + // request comes to server + c.serverRequest.Id = c.msg.Id + c.serverRequest.Method = c.msg.Method + c.serverRequest.Params = c.msg.Params + + req.Method = c.serverRequest.Method + + // JSON request id can be any JSON value; + // RPC package expects uint64. Translate to + // internal uint64 and save JSON on the side. + if c.serverRequest.Id == nil { + // Notification + } else { + c.mutex.Lock() + c.seq++ + c.pending[c.seq] = c.serverRequest.Id + c.serverRequest.Id = nil + req.Seq = c.seq + c.mutex.Unlock() + } + } else { + // response comes to client + err := json.Unmarshal(*c.msg.Id, &c.clientResponse.Id) + if err != nil { + return err + } + c.clientResponse.Result = c.msg.Result + c.clientResponse.Error = c.msg.Error + + resp.Error = "" + resp.Seq = c.clientResponse.Id + if c.clientResponse.Error != nil || c.clientResponse.Result == nil { + x, ok := c.clientResponse.Error.(string) + if !ok { + return fmt.Errorf("invalid error %v", c.clientResponse.Error) + } + if x == "" { + x = "unspecified error" + } + resp.Error = x + } + } + return nil +} + +var errMissingParams = errors.New("jsonrpc: request body missing params") + +func (c *jsonCodec) ReadRequestBody(x interface{}) error { + if x == nil { + return nil + } + if c.serverRequest.Params == nil { + return errMissingParams + } + + var err error + + // Check if x points to a slice of any kind + rt := reflect.TypeOf(x) + if rt.Kind() == reflect.Ptr && rt.Elem().Kind() == reflect.Slice { + // If it's a slice, unmarshal as is + err = json.Unmarshal(*c.serverRequest.Params, x) + } else { + // Anything else unmarshal into a slice containing x + params := &[]interface{}{x} + err = json.Unmarshal(*c.serverRequest.Params, params) + } + + return err +} + +func (c *jsonCodec) ReadResponseBody(x interface{}) error { + if x == nil { + return nil + } + return json.Unmarshal(*c.clientResponse.Result, x) +} + +func (c *jsonCodec) WriteRequest(r *rpc2.Request, param interface{}) error { + req := &clientRequest{Method: r.Method} + + // Check if param is a slice of any kind + if param != nil && reflect.TypeOf(param).Kind() == reflect.Slice { + // If it's a slice, leave as is + 
req.Params = param + } else { + // Put anything else into a slice + req.Params = []interface{}{param} + } + + if r.Seq == 0 { + // Notification + req.Id = nil + } else { + seq := r.Seq + req.Id = &seq + } + return c.enc.Encode(req) +} + +var null = json.RawMessage([]byte("null")) + +func (c *jsonCodec) WriteResponse(r *rpc2.Response, x interface{}) error { + c.mutex.Lock() + b, ok := c.pending[r.Seq] + if !ok { + c.mutex.Unlock() + return errors.New("invalid sequence number in response") + } + delete(c.pending, r.Seq) + c.mutex.Unlock() + + if b == nil { + // Invalid request so no id. Use JSON null. + b = &null + } + resp := serverResponse{Id: b} + if r.Error == "" { + resp.Result = x + } else { + resp.Error = r.Error + } + return c.enc.Encode(resp) +} + +func (c *jsonCodec) Close() error { + return c.c.Close() +} diff --git a/vendor/github.com/cenkalti/rpc2/jsonrpc/jsonrpc_test.go b/vendor/github.com/cenkalti/rpc2/jsonrpc/jsonrpc_test.go new file mode 100644 index 00000000..390da362 --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/jsonrpc/jsonrpc_test.go @@ -0,0 +1,182 @@ +package jsonrpc + +import ( + "encoding/json" + "fmt" + "net" + "testing" + "time" + + "github.com/cenkalti/rpc2" +) + +const ( + network = "tcp4" + addr = "127.0.0.1:5000" +) + +func TestJSONRPC(t *testing.T) { + type Args struct{ A, B int } + type Reply int + + lis, err := net.Listen(network, addr) + if err != nil { + t.Fatal(err) + } + + srv := rpc2.NewServer() + srv.Handle("add", func(client *rpc2.Client, args *Args, reply *Reply) error { + *reply = Reply(args.A + args.B) + + var rep Reply + err := client.Call("mult", Args{2, 3}, &rep) + if err != nil { + t.Fatal(err) + } + + if rep != 6 { + t.Fatalf("not expected: %d", rep) + } + + return nil + }) + srv.Handle("addPos", func(client *rpc2.Client, args []interface{}, result *float64) error { + *result = args[0].(float64) + args[1].(float64) + return nil + }) + srv.Handle("rawArgs", func(client *rpc2.Client, args []json.RawMessage, reply *[]string) error { + for _, p := range args { + var str string + json.Unmarshal(p, &str) + *reply = append(*reply, str) + } + return nil + }) + srv.Handle("typedArgs", func(client *rpc2.Client, args []int, reply *[]string) error { + for _, p := range args { + *reply = append(*reply, fmt.Sprintf("%d", p)) + } + return nil + }) + srv.Handle("nilArgs", func(client *rpc2.Client, args []interface{}, reply *[]string) error { + for _, v := range args { + if v == nil { + *reply = append(*reply, "nil") + } + } + return nil + }) + number := make(chan int, 1) + srv.Handle("set", func(client *rpc2.Client, i int, _ *struct{}) error { + number <- i + return nil + }) + + go func() { + conn, err := lis.Accept() + if err != nil { + t.Fatal(err) + } + srv.ServeCodec(NewJSONCodec(conn)) + }() + + conn, err := net.Dial(network, addr) + if err != nil { + t.Fatal(err) + } + + clt := rpc2.NewClientWithCodec(NewJSONCodec(conn)) + clt.Handle("mult", func(client *rpc2.Client, args *Args, reply *Reply) error { + *reply = Reply(args.A * args.B) + return nil + }) + go clt.Run() + + // Test Call. + var rep Reply + err = clt.Call("add", Args{1, 2}, &rep) + if err != nil { + t.Fatal(err) + } + if rep != 3 { + t.Fatalf("not expected: %d", rep) + } + + // Test notification. + err = clt.Notify("set", 6) + if err != nil { + t.Fatal(err) + } + select { + case i := <-number: + if i != 6 { + t.Fatalf("unexpected number: %d", i) + } + case <-time.After(time.Second): + t.Fatal("did not get notification") + } + + // Test undefined method. 
+ err = clt.Call("foo", 1, &rep) + if err.Error() != "rpc2: can't find method foo" { + t.Fatal(err) + } + + // Test Positional arguments. + var result float64 + err = clt.Call("addPos", []interface{}{1, 2}, &result) + if err != nil { + t.Fatal(err) + } + if result != 3 { + t.Fatalf("not expected: %f", result) + } + + testArgs := func(expected, reply []string) error { + if len(reply) != len(expected) { + return fmt.Errorf("incorrect reply length: %d", len(reply)) + } + for i := range expected { + if reply[i] != expected[i] { + return fmt.Errorf("not expected reply[%d]: %s", i, reply[i]) + } + } + return nil + } + + // Test raw arguments (partial unmarshal) + var reply []string + var expected []string = []string{"arg1", "arg2"} + rawArgs := json.RawMessage(`["arg1", "arg2"]`) + err = clt.Call("rawArgs", rawArgs, &reply) + if err != nil { + t.Fatal(err) + } + + if err = testArgs(expected, reply); err != nil { + t.Fatal(err) + } + + // Test typed arguments + reply = []string{} + expected = []string{"1", "2"} + typedArgs := []int{1, 2} + err = clt.Call("typedArgs", typedArgs, &reply) + if err != nil { + t.Fatal(err) + } + if err = testArgs(expected, reply); err != nil { + t.Fatal(err) + } + + // Test nil args + reply = []string{} + expected = []string{"nil"} + err = clt.Call("nilArgs", nil, &reply) + if err != nil { + t.Fatal(err) + } + if err = testArgs(expected, reply); err != nil { + t.Fatal(err) + } +} diff --git a/vendor/github.com/cenkalti/rpc2/rpc2_test.go b/vendor/github.com/cenkalti/rpc2/rpc2_test.go new file mode 100644 index 00000000..5ce97e63 --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/rpc2_test.go @@ -0,0 +1,98 @@ +package rpc2 + +import ( + "net" + "testing" + "time" +) + +const ( + network = "tcp4" + addr = "127.0.0.1:5000" +) + +func TestTCPGOB(t *testing.T) { + type Args struct{ A, B int } + type Reply int + + lis, err := net.Listen(network, addr) + if err != nil { + t.Fatal(err) + } + + srv := NewServer() + srv.Handle("add", func(client *Client, args *Args, reply *Reply) error { + *reply = Reply(args.A + args.B) + + var rep Reply + err := client.Call("mult", Args{2, 3}, &rep) + if err != nil { + t.Fatal(err) + } + + if rep != 6 { + t.Fatalf("not expected: %d", rep) + } + + return nil + }) + number := make(chan int, 1) + srv.Handle("set", func(client *Client, i int, _ *struct{}) error { + number <- i + return nil + }) + go srv.Accept(lis) + + conn, err := net.Dial(network, addr) + if err != nil { + t.Fatal(err) + } + + clt := NewClient(conn) + clt.Handle("mult", func(client *Client, args *Args, reply *Reply) error { + *reply = Reply(args.A * args.B) + return nil + }) + go clt.Run() + defer clt.Close() + + // Test Call. + var rep Reply + err = clt.Call("add", Args{1, 2}, &rep) + if err != nil { + t.Fatal(err) + } + if rep != 3 { + t.Fatalf("not expected: %d", rep) + } + + // Test notification. + err = clt.Notify("set", 6) + if err != nil { + t.Fatal(err) + } + select { + case i := <-number: + if i != 6 { + t.Fatalf("unexpected number: %d", i) + } + case <-time.After(time.Second): + t.Fatal("did not get notification") + } + + // Test blocked request + clt.SetBlocking(true) + err = clt.Call("add", Args{1, 2}, &rep) + if err != nil { + t.Fatal(err) + } + if rep != 3 { + t.Fatalf("not expected: %d", rep) + } + + // Test undefined method. 
+ err = clt.Call("foo", 1, &rep) + if err.Error() != "rpc2: can't find method foo" { + t.Fatal(err) + } +} diff --git a/vendor/github.com/cenkalti/rpc2/server.go b/vendor/github.com/cenkalti/rpc2/server.go new file mode 100644 index 00000000..2a5be7ed --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/server.go @@ -0,0 +1,181 @@ +package rpc2 + +import ( + "io" + "log" + "net" + "reflect" + "unicode" + "unicode/utf8" + + "github.com/cenkalti/hub" +) + +// Precompute the reflect type for error. Can't use error directly +// because Typeof takes an empty interface value. This is annoying. +var typeOfError = reflect.TypeOf((*error)(nil)).Elem() +var typeOfClient = reflect.TypeOf((*Client)(nil)) + +const ( + clientConnected hub.Kind = iota + clientDisconnected +) + +// Server responds to RPC requests made by Client. +type Server struct { + handlers map[string]*handler + eventHub *hub.Hub +} + +type handler struct { + fn reflect.Value + argType reflect.Type + replyType reflect.Type +} + +type connectionEvent struct { + Client *Client +} + +type disconnectionEvent struct { + Client *Client +} + +func (connectionEvent) Kind() hub.Kind { return clientConnected } +func (disconnectionEvent) Kind() hub.Kind { return clientDisconnected } + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{ + handlers: make(map[string]*handler), + eventHub: &hub.Hub{}, + } +} + +// Handle registers the handler function for the given method. If a handler already exists for method, Handle panics. +func (s *Server) Handle(method string, handlerFunc interface{}) { + addHandler(s.handlers, method, handlerFunc) +} + +func addHandler(handlers map[string]*handler, mname string, handlerFunc interface{}) { + if _, ok := handlers[mname]; ok { + panic("rpc2: multiple registrations for " + mname) + } + + method := reflect.ValueOf(handlerFunc) + mtype := method.Type() + // Method needs three ins: *client, *args, *reply. + if mtype.NumIn() != 3 { + log.Panicln("method", mname, "has wrong number of ins:", mtype.NumIn()) + } + // First arg must be a pointer to rpc2.Client. + clientType := mtype.In(0) + if clientType.Kind() != reflect.Ptr { + log.Panicln("method", mname, "client type not a pointer:", clientType) + } + if clientType != typeOfClient { + log.Panicln("method", mname, "first argument", clientType.String(), "not *rpc2.Client") + } + // Second arg need not be a pointer. + argType := mtype.In(1) + if !isExportedOrBuiltinType(argType) { + log.Panicln(mname, "argument type not exported:", argType) + } + // Third arg must be a pointer. + replyType := mtype.In(2) + if replyType.Kind() != reflect.Ptr { + log.Panicln("method", mname, "reply type not a pointer:", replyType) + } + // Reply type must be exported. + if !isExportedOrBuiltinType(replyType) { + log.Panicln("method", mname, "reply type not exported:", replyType) + } + // Method needs one out. + if mtype.NumOut() != 1 { + log.Panicln("method", mname, "has wrong number of outs:", mtype.NumOut()) + } + // The return type of the method must be error. + if returnType := mtype.Out(0); returnType != typeOfError { + log.Panicln("method", mname, "returns", returnType.String(), "not error") + } + handlers[mname] = &handler{ + fn: method, + argType: argType, + replyType: replyType, + } +} + +// Is this type exported or a builtin? +func isExportedOrBuiltinType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + // PkgPath will be non-empty even for an exported type, + // so we need to check the type name as well. 
+ return isExported(t.Name()) || t.PkgPath() == "" +} + +// Is this an exported - upper case - name? +func isExported(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + +// OnConnect registers a function to run when a client connects. +func (s *Server) OnConnect(f func(*Client)) { + s.eventHub.Subscribe(clientConnected, func(e hub.Event) { + go f(e.(connectionEvent).Client) + }) +} + +// OnDisconnect registers a function to run when a client disconnects. +func (s *Server) OnDisconnect(f func(*Client)) { + s.eventHub.Subscribe(clientDisconnected, func(e hub.Event) { + go f(e.(disconnectionEvent).Client) + }) +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. Accept blocks; the caller typically +// invokes it in a go statement. +func (s *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Print("rpc.Serve: accept:", err.Error()) + return + } + go s.ServeConn(conn) + } +} + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +// ServeConn uses the gob wire format (see package gob) on the +// connection. To use an alternate codec, use ServeCodec. +func (s *Server) ServeConn(conn io.ReadWriteCloser) { + s.ServeCodec(NewGobCodec(conn)) +} + +// ServeCodec is like ServeConn but uses the specified codec to +// decode requests and encode responses. +func (s *Server) ServeCodec(codec Codec) { + s.ServeCodecWithState(codec, NewState()) +} + +// ServeCodecWithState is like ServeCodec but also gives the ability to +// associate a state variable with the client that persists across RPC calls. +func (s *Server) ServeCodecWithState(codec Codec, state *State) { + defer codec.Close() + + // Client also handles the incoming connections. + c := NewClientWithCodec(codec) + c.server = true + c.handlers = s.handlers + c.State = state + + s.eventHub.Publish(connectionEvent{c}) + c.Run() + s.eventHub.Publish(disconnectionEvent{c}) +} diff --git a/vendor/github.com/cenkalti/rpc2/state.go b/vendor/github.com/cenkalti/rpc2/state.go new file mode 100644 index 00000000..7a4f23e6 --- /dev/null +++ b/vendor/github.com/cenkalti/rpc2/state.go @@ -0,0 +1,25 @@ +package rpc2 + +import "sync" + +type State struct { + store map[string]interface{} + m sync.RWMutex +} + +func NewState() *State { + return &State{store: make(map[string]interface{})} +} + +func (s *State) Get(key string) (value interface{}, ok bool) { + s.m.RLock() + value, ok = s.store[key] + s.m.RUnlock() + return +} + +func (s *State) Set(key string, value interface{}) { + s.m.Lock() + s.store[key] = value + s.m.Unlock() +} diff --git a/vendor/github.com/patrickmn/go-cache/CONTRIBUTORS b/vendor/github.com/patrickmn/go-cache/CONTRIBUTORS new file mode 100644 index 00000000..2b16e997 --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/CONTRIBUTORS @@ -0,0 +1,9 @@ +This is a list of people who have contributed code to go-cache. They, or their +employers, are the copyright holders of the contributed code. Contributed code +is subject to the license restrictions listed in LICENSE (as they were when the +code was contributed.) 
+ +Dustin Sallings +Jason Mooberry +Sergey Shepelev +Alex Edwards diff --git a/vendor/github.com/patrickmn/go-cache/LICENSE b/vendor/github.com/patrickmn/go-cache/LICENSE new file mode 100644 index 00000000..f49969d7 --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2012-2019 Patrick Mylund Nielsen and the go-cache contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/patrickmn/go-cache/README.md b/vendor/github.com/patrickmn/go-cache/README.md new file mode 100644 index 00000000..c5789cc6 --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/README.md @@ -0,0 +1,83 @@ +# go-cache + +go-cache is an in-memory key:value store/cache similar to memcached that is +suitable for applications running on a single machine. Its major advantage is +that, being essentially a thread-safe `map[string]interface{}` with expiration +times, it doesn't need to serialize or transmit its contents over the network. + +Any object can be stored, for a given duration or forever, and the cache can be +safely used by multiple goroutines. + +Although go-cache isn't meant to be used as a persistent datastore, the entire +cache can be saved to and loaded from a file (using `c.Items()` to retrieve the +items map to serialize, and `NewFrom()` to create a cache from a deserialized +one) to recover from downtime quickly. (See the docs for `NewFrom()` for caveats.) + +### Installation + +`go get github.com/patrickmn/go-cache` + +### Usage + +```go +import ( + "fmt" + "github.com/patrickmn/go-cache" + "time" +) + +func main() { + // Create a cache with a default expiration time of 5 minutes, and which + // purges expired items every 10 minutes + c := cache.New(5*time.Minute, 10*time.Minute) + + // Set the value of the key "foo" to "bar", with the default expiration time + c.Set("foo", "bar", cache.DefaultExpiration) + + // Set the value of the key "baz" to 42, with no expiration time + // (the item won't be removed until it is re-set, or removed using + // c.Delete("baz") + c.Set("baz", 42, cache.NoExpiration) + + // Get the string associated with the key "foo" from the cache + foo, found := c.Get("foo") + if found { + fmt.Println(foo) + } + + // Since Go is statically typed, and cache values can be anything, type + // assertion is needed when values are being passed to functions that don't + // take arbitrary types, (i.e. interface{}). The simplest way to do this for + // values which will only be used once--e.g. 
for passing to another
+	// function--is:
+	foo, found := c.Get("foo")
+	if found {
+		MyFunction(foo.(string))
+	}
+
+	// This gets tedious if the value is used several times in the same function.
+	// You might do either of the following instead:
+	if x, found := c.Get("foo"); found {
+		foo := x.(string)
+		// ...
+	}
+	// or
+	var foo string
+	if x, found := c.Get("foo"); found {
+		foo = x.(string)
+	}
+	// ...
+	// foo can then be passed around freely as a string
+
+	// Want performance? Store pointers!
+	c.Set("foo", &MyStruct, cache.DefaultExpiration)
+	if x, found := c.Get("foo"); found {
+		foo := x.(*MyStruct)
+		// ...
+	}
+}
+```
+
+### Reference
+
+`godoc` or [http://godoc.org/github.com/patrickmn/go-cache](http://godoc.org/github.com/patrickmn/go-cache)

diff --git a/vendor/github.com/patrickmn/go-cache/cache.go b/vendor/github.com/patrickmn/go-cache/cache.go
new file mode 100644
index 00000000..db88d2f2
--- /dev/null
+++ b/vendor/github.com/patrickmn/go-cache/cache.go
@@ -0,0 +1,1161 @@
+package cache
+
+import (
+	"encoding/gob"
+	"fmt"
+	"io"
+	"os"
+	"runtime"
+	"sync"
+	"time"
+)
+
+type Item struct {
+	Object     interface{}
+	Expiration int64
+}
+
+// Returns true if the item has expired.
+func (item Item) Expired() bool {
+	if item.Expiration == 0 {
+		return false
+	}
+	return time.Now().UnixNano() > item.Expiration
+}
+
+const (
+	// For use with functions that take an expiration time.
+	NoExpiration time.Duration = -1
+	// For use with functions that take an expiration time. Equivalent to
+	// passing in the same expiration duration as was given to New() or
+	// NewFrom() when the cache was created (e.g. 5 minutes.)
+	DefaultExpiration time.Duration = 0
+)
+
+type Cache struct {
+	*cache
+	// If this is confusing, see the comment at the bottom of New()
+}
+
+type cache struct {
+	defaultExpiration time.Duration
+	items             map[string]Item
+	mu                sync.RWMutex
+	onEvicted         func(string, interface{})
+	janitor           *janitor
+}
+
+// Add an item to the cache, replacing any existing item. If the duration is 0
+// (DefaultExpiration), the cache's default expiration time is used. If it is -1
+// (NoExpiration), the item never expires.
+func (c *cache) Set(k string, x interface{}, d time.Duration) {
+	// "Inlining" of set
+	var e int64
+	if d == DefaultExpiration {
+		d = c.defaultExpiration
+	}
+	if d > 0 {
+		e = time.Now().Add(d).UnixNano()
+	}
+	c.mu.Lock()
+	c.items[k] = Item{
+		Object:     x,
+		Expiration: e,
+	}
+	// TODO: Calls to mu.Unlock are currently not deferred because defer
+	// adds ~200 ns (as of go1.)
+	c.mu.Unlock()
+}
+
+func (c *cache) set(k string, x interface{}, d time.Duration) {
+	var e int64
+	if d == DefaultExpiration {
+		d = c.defaultExpiration
+	}
+	if d > 0 {
+		e = time.Now().Add(d).UnixNano()
+	}
+	c.items[k] = Item{
+		Object:     x,
+		Expiration: e,
+	}
+}
+
+// Add an item to the cache, replacing any existing item, using the default
+// expiration.
+func (c *cache) SetDefault(k string, x interface{}) {
+	c.Set(k, x, DefaultExpiration)
+}
+
+// Add an item to the cache only if an item doesn't already exist for the given
+// key, or if the existing item has expired. Returns an error otherwise.
+func (c *cache) Add(k string, x interface{}, d time.Duration) error {
+	c.mu.Lock()
+	_, found := c.get(k)
+	if found {
+		c.mu.Unlock()
+		return fmt.Errorf("Item %s already exists", k)
+	}
+	c.set(k, x, d)
+	c.mu.Unlock()
+	return nil
+}
+
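Set, Add, and Replace differ only in their precondition on the key, as the doc comments here spell out: Add requires the key to be absent (or expired), Replace (defined just below) requires it to be present, and Set is unconditional. A small, self-contained sketch of those semantics using the public go-cache API (key names and values are arbitrary; `cache.New` is defined later in this file):

```go
package main

import (
	"fmt"
	"time"

	"github.com/patrickmn/go-cache"
)

func main() {
	c := cache.New(5*time.Minute, 10*time.Minute)

	// Add succeeds only when the key is absent (or its old value expired).
	fmt.Println(c.Add("k", 1, cache.DefaultExpiration)) // <nil>
	fmt.Println(c.Add("k", 2, cache.DefaultExpiration)) // Item k already exists

	// Replace succeeds only when the key is already present and unexpired.
	fmt.Println(c.Replace("k", 3, cache.DefaultExpiration)) // <nil>
	fmt.Println(c.Replace("x", 4, cache.DefaultExpiration)) // Item x doesn't exist

	// Set is unconditional: it inserts or overwrites.
	c.Set("x", 5, cache.DefaultExpiration)
}
```

+// Set a new value for the cache key only if it already exists, and the existing
+// item hasn't expired. Returns an error otherwise.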
+func (c *cache) Replace(k string, x interface{}, d time.Duration) error { + c.mu.Lock() + _, found := c.get(k) + if !found { + c.mu.Unlock() + return fmt.Errorf("Item %s doesn't exist", k) + } + c.set(k, x, d) + c.mu.Unlock() + return nil +} + +// Get an item from the cache. Returns the item or nil, and a bool indicating +// whether the key was found. +func (c *cache) Get(k string) (interface{}, bool) { + c.mu.RLock() + // "Inlining" of get and Expired + item, found := c.items[k] + if !found { + c.mu.RUnlock() + return nil, false + } + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + c.mu.RUnlock() + return nil, false + } + } + c.mu.RUnlock() + return item.Object, true +} + +// GetWithExpiration returns an item and its expiration time from the cache. +// It returns the item or nil, the expiration time if one is set (if the item +// never expires a zero value for time.Time is returned), and a bool indicating +// whether the key was found. +func (c *cache) GetWithExpiration(k string) (interface{}, time.Time, bool) { + c.mu.RLock() + // "Inlining" of get and Expired + item, found := c.items[k] + if !found { + c.mu.RUnlock() + return nil, time.Time{}, false + } + + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + c.mu.RUnlock() + return nil, time.Time{}, false + } + + // Return the item and the expiration time + c.mu.RUnlock() + return item.Object, time.Unix(0, item.Expiration), true + } + + // If expiration <= 0 (i.e. no expiration time set) then return the item + // and a zeroed time.Time + c.mu.RUnlock() + return item.Object, time.Time{}, true +} + +func (c *cache) get(k string) (interface{}, bool) { + item, found := c.items[k] + if !found { + return nil, false + } + // "Inlining" of Expired + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + return nil, false + } + } + return item.Object, true +} + +// Increment an item of type int, int8, int16, int32, int64, uintptr, uint, +// uint8, uint32, or uint64, float32 or float64 by n. Returns an error if the +// item's value is not an integer, if it was not found, or if it is not +// possible to increment it by n. To retrieve the incremented value, use one +// of the specialized methods, e.g. IncrementInt64. +func (c *cache) Increment(k string, n int64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case int: + v.Object = v.Object.(int) + int(n) + case int8: + v.Object = v.Object.(int8) + int8(n) + case int16: + v.Object = v.Object.(int16) + int16(n) + case int32: + v.Object = v.Object.(int32) + int32(n) + case int64: + v.Object = v.Object.(int64) + n + case uint: + v.Object = v.Object.(uint) + uint(n) + case uintptr: + v.Object = v.Object.(uintptr) + uintptr(n) + case uint8: + v.Object = v.Object.(uint8) + uint8(n) + case uint16: + v.Object = v.Object.(uint16) + uint16(n) + case uint32: + v.Object = v.Object.(uint32) + uint32(n) + case uint64: + v.Object = v.Object.(uint64) + uint64(n) + case float32: + v.Object = v.Object.(float32) + float32(n) + case float64: + v.Object = v.Object.(float64) + float64(n) + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s is not an integer", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Increment an item of type float32 or float64 by n. Returns an error if the +// item's value is not floating point, if it was not found, or if it is not +// possible to increment it by n. 
Pass a negative number to decrement the +// value. To retrieve the incremented value, use one of the specialized methods, +// e.g. IncrementFloat64. +func (c *cache) IncrementFloat(k string, n float64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case float32: + v.Object = v.Object.(float32) + float32(n) + case float64: + v.Object = v.Object.(float64) + n + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s does not have type float32 or float64", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Increment an item of type int by n. Returns an error if the item's value is +// not an int, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt(k string, n int) (int, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int8 by n. Returns an error if the item's value is +// not an int8, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt8(k string, n int8) (int8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int8", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int16 by n. Returns an error if the item's value is +// not an int16, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt16(k string, n int16) (int16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int16", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int32 by n. Returns an error if the item's value is +// not an int32, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt32(k string, n int32) (int32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int64 by n. Returns an error if the item's value is +// not an int64, or if it was not found. If there is no error, the incremented +// value is returned. 
+func (c *cache) IncrementInt64(k string, n int64) (int64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint by n. Returns an error if the item's value is +// not an uint, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementUint(k string, n uint) (uint, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uintptr by n. Returns an error if the item's value +// is not an uintptr, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUintptr(k string, n uintptr) (uintptr, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uintptr) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uintptr", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint8 by n. Returns an error if the item's value +// is not an uint8, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint8(k string, n uint8) (uint8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint8", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint16 by n. Returns an error if the item's value +// is not an uint16, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint16(k string, n uint16) (uint16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint16", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint32 by n. Returns an error if the item's value +// is not an uint32, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint32(k string, n uint32) (uint32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint64 by n. Returns an error if the item's value +// is not an uint64, or if it was not found. 
If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint64(k string, n uint64) (uint64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type float32 by n. Returns an error if the item's value +// is not an float32, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementFloat32(k string, n float32) (float32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type float64 by n. Returns an error if the item's value +// is not an float64, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementFloat64(k string, n float64) (float64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int, int8, int16, int32, int64, uintptr, uint, +// uint8, uint32, or uint64, float32 or float64 by n. Returns an error if the +// item's value is not an integer, if it was not found, or if it is not +// possible to decrement it by n. To retrieve the decremented value, use one +// of the specialized methods, e.g. DecrementInt64. +func (c *cache) Decrement(k string, n int64) error { + // TODO: Implement Increment and Decrement more cleanly. + // (Cannot do Increment(k, n*-1) for uints.) + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item not found") + } + switch v.Object.(type) { + case int: + v.Object = v.Object.(int) - int(n) + case int8: + v.Object = v.Object.(int8) - int8(n) + case int16: + v.Object = v.Object.(int16) - int16(n) + case int32: + v.Object = v.Object.(int32) - int32(n) + case int64: + v.Object = v.Object.(int64) - n + case uint: + v.Object = v.Object.(uint) - uint(n) + case uintptr: + v.Object = v.Object.(uintptr) - uintptr(n) + case uint8: + v.Object = v.Object.(uint8) - uint8(n) + case uint16: + v.Object = v.Object.(uint16) - uint16(n) + case uint32: + v.Object = v.Object.(uint32) - uint32(n) + case uint64: + v.Object = v.Object.(uint64) - uint64(n) + case float32: + v.Object = v.Object.(float32) - float32(n) + case float64: + v.Object = v.Object.(float64) - float64(n) + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s is not an integer", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Decrement an item of type float32 or float64 by n. Returns an error if the +// item's value is not floating point, if it was not found, or if it is not +// possible to decrement it by n. Pass a negative number to decrement the +// value. 
To retrieve the decremented value, use one of the specialized methods, +// e.g. DecrementFloat64. +func (c *cache) DecrementFloat(k string, n float64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case float32: + v.Object = v.Object.(float32) - float32(n) + case float64: + v.Object = v.Object.(float64) - n + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s does not have type float32 or float64", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Decrement an item of type int by n. Returns an error if the item's value is +// not an int, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt(k string, n int) (int, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int8 by n. Returns an error if the item's value is +// not an int8, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt8(k string, n int8) (int8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int8", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int16 by n. Returns an error if the item's value is +// not an int16, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt16(k string, n int16) (int16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int16", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int32 by n. Returns an error if the item's value is +// not an int32, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt32(k string, n int32) (int32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int64 by n. Returns an error if the item's value is +// not an int64, or if it was not found. If there is no error, the decremented +// value is returned. 
+func (c *cache) DecrementInt64(k string, n int64) (int64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint by n. Returns an error if the item's value is +// not an uint, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementUint(k string, n uint) (uint, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uintptr by n. Returns an error if the item's value +// is not an uintptr, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUintptr(k string, n uintptr) (uintptr, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uintptr) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uintptr", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint8 by n. Returns an error if the item's value is +// not an uint8, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementUint8(k string, n uint8) (uint8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint8", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint16 by n. Returns an error if the item's value +// is not an uint16, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint16(k string, n uint16) (uint16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint16", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint32 by n. Returns an error if the item's value +// is not an uint32, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint32(k string, n uint32) (uint32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint64 by n. Returns an error if the item's value +// is not an uint64, or if it was not found. 
If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint64(k string, n uint64) (uint64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type float32 by n. Returns an error if the item's value +// is not an float32, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementFloat32(k string, n float32) (float32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type float64 by n. Returns an error if the item's value +// is not an float64, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementFloat64(k string, n float64) (float64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Delete an item from the cache. Does nothing if the key is not in the cache. +func (c *cache) Delete(k string) { + c.mu.Lock() + v, evicted := c.delete(k) + c.mu.Unlock() + if evicted { + c.onEvicted(k, v) + } +} + +func (c *cache) delete(k string) (interface{}, bool) { + if c.onEvicted != nil { + if v, found := c.items[k]; found { + delete(c.items, k) + return v.Object, true + } + } + delete(c.items, k) + return nil, false +} + +type keyAndValue struct { + key string + value interface{} +} + +// Delete all expired items from the cache. +func (c *cache) DeleteExpired() { + var evictedItems []keyAndValue + now := time.Now().UnixNano() + c.mu.Lock() + for k, v := range c.items { + // "Inlining" of expired + if v.Expiration > 0 && now > v.Expiration { + ov, evicted := c.delete(k) + if evicted { + evictedItems = append(evictedItems, keyAndValue{k, ov}) + } + } + } + c.mu.Unlock() + for _, v := range evictedItems { + c.onEvicted(v.key, v.value) + } +} + +// Sets an (optional) function that is called with the key and value when an +// item is evicted from the cache. (Including when it is deleted manually, but +// not when it is overwritten.) Set to nil to disable. +func (c *cache) OnEvicted(f func(string, interface{})) { + c.mu.Lock() + c.onEvicted = f + c.mu.Unlock() +} + +// Write the cache's items (using Gob) to an io.Writer. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) 
+func (c *cache) Save(w io.Writer) (err error) { + enc := gob.NewEncoder(w) + defer func() { + if x := recover(); x != nil { + err = fmt.Errorf("Error registering item types with Gob library") + } + }() + c.mu.RLock() + defer c.mu.RUnlock() + for _, v := range c.items { + gob.Register(v.Object) + } + err = enc.Encode(&c.items) + return +} + +// Save the cache's items to the given filename, creating the file if it +// doesn't exist, and overwriting it if it does. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) SaveFile(fname string) error { + fp, err := os.Create(fname) + if err != nil { + return err + } + err = c.Save(fp) + if err != nil { + fp.Close() + return err + } + return fp.Close() +} + +// Add (Gob-serialized) cache items from an io.Reader, excluding any items with +// keys that already exist (and haven't expired) in the current cache. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) Load(r io.Reader) error { + dec := gob.NewDecoder(r) + items := map[string]Item{} + err := dec.Decode(&items) + if err == nil { + c.mu.Lock() + defer c.mu.Unlock() + for k, v := range items { + ov, found := c.items[k] + if !found || ov.Expired() { + c.items[k] = v + } + } + } + return err +} + +// Load and add cache items from the given filename, excluding any items with +// keys that already exist in the current cache. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) LoadFile(fname string) error { + fp, err := os.Open(fname) + if err != nil { + return err + } + err = c.Load(fp) + if err != nil { + fp.Close() + return err + } + return fp.Close() +} + +// Copies all unexpired items in the cache into a new map and returns it. +func (c *cache) Items() map[string]Item { + c.mu.RLock() + defer c.mu.RUnlock() + m := make(map[string]Item, len(c.items)) + now := time.Now().UnixNano() + for k, v := range c.items { + // "Inlining" of Expired + if v.Expiration > 0 { + if now > v.Expiration { + continue + } + } + m[k] = v + } + return m +} + +// Returns the number of items in the cache. This may include items that have +// expired, but have not yet been cleaned up. +func (c *cache) ItemCount() int { + c.mu.RLock() + n := len(c.items) + c.mu.RUnlock() + return n +} + +// Delete all items from the cache. +func (c *cache) Flush() { + c.mu.Lock() + c.items = map[string]Item{} + c.mu.Unlock() +} + +type janitor struct { + Interval time.Duration + stop chan bool +} + +func (j *janitor) Run(c *cache) { + ticker := time.NewTicker(j.Interval) + for { + select { + case <-ticker.C: + c.DeleteExpired() + case <-j.stop: + ticker.Stop() + return + } + } +} + +func stopJanitor(c *Cache) { + c.janitor.stop <- true +} + +func runJanitor(c *cache, ci time.Duration) { + j := &janitor{ + Interval: ci, + stop: make(chan bool), + } + c.janitor = j + go j.Run(c) +} + +func newCache(de time.Duration, m map[string]Item) *cache { + if de == 0 { + de = -1 + } + c := &cache{ + defaultExpiration: de, + items: m, + } + return c +} + +func newCacheWithJanitor(de time.Duration, ci time.Duration, m map[string]Item) *Cache { + c := newCache(de, m) + // This trick ensures that the janitor goroutine (which--granted it + // was enabled--is running DeleteExpired on c forever) does not keep + // the returned C object from being garbage collected. 
When it is + // garbage collected, the finalizer stops the janitor goroutine, after + // which c can be collected. + C := &Cache{c} + if ci > 0 { + runJanitor(c, ci) + runtime.SetFinalizer(C, stopJanitor) + } + return C +} + +// Return a new cache with a given default expiration duration and cleanup +// interval. If the expiration duration is less than one (or NoExpiration), +// the items in the cache never expire (by default), and must be deleted +// manually. If the cleanup interval is less than one, expired items are not +// deleted from the cache before calling c.DeleteExpired(). +func New(defaultExpiration, cleanupInterval time.Duration) *Cache { + items := make(map[string]Item) + return newCacheWithJanitor(defaultExpiration, cleanupInterval, items) +} + +// Return a new cache with a given default expiration duration and cleanup +// interval. If the expiration duration is less than one (or NoExpiration), +// the items in the cache never expire (by default), and must be deleted +// manually. If the cleanup interval is less than one, expired items are not +// deleted from the cache before calling c.DeleteExpired(). +// +// NewFrom() also accepts an items map which will serve as the underlying map +// for the cache. This is useful for starting from a deserialized cache +// (serialized using e.g. gob.Encode() on c.Items()), or passing in e.g. +// make(map[string]Item, 500) to improve startup performance when the cache +// is expected to reach a certain minimum size. +// +// Only the cache's methods synchronize access to this map, so it is not +// recommended to keep any references to the map around after creating a cache. +// If need be, the map can be accessed at a later point using c.Items() (subject +// to the same caveat.) +// +// Note regarding serialization: When using e.g. gob, make sure to +// gob.Register() the individual types stored in the cache before encoding a +// map retrieved with c.Items(), and to register those same types before +// decoding a blob containing an items map. 
+func NewFrom(defaultExpiration, cleanupInterval time.Duration, items map[string]Item) *Cache { + return newCacheWithJanitor(defaultExpiration, cleanupInterval, items) +} diff --git a/vendor/github.com/patrickmn/go-cache/cache_test.go b/vendor/github.com/patrickmn/go-cache/cache_test.go new file mode 100644 index 00000000..de3e9d6b --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/cache_test.go @@ -0,0 +1,1771 @@ +package cache + +import ( + "bytes" + "io/ioutil" + "runtime" + "strconv" + "sync" + "testing" + "time" +) + +type TestStruct struct { + Num int + Children []*TestStruct +} + +func TestCache(t *testing.T) { + tc := New(DefaultExpiration, 0) + + a, found := tc.Get("a") + if found || a != nil { + t.Error("Getting A found value that shouldn't exist:", a) + } + + b, found := tc.Get("b") + if found || b != nil { + t.Error("Getting B found value that shouldn't exist:", b) + } + + c, found := tc.Get("c") + if found || c != nil { + t.Error("Getting C found value that shouldn't exist:", c) + } + + tc.Set("a", 1, DefaultExpiration) + tc.Set("b", "b", DefaultExpiration) + tc.Set("c", 3.5, DefaultExpiration) + + x, found := tc.Get("a") + if !found { + t.Error("a was not found while getting a2") + } + if x == nil { + t.Error("x for a is nil") + } else if a2 := x.(int); a2+2 != 3 { + t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2) + } + + x, found = tc.Get("b") + if !found { + t.Error("b was not found while getting b2") + } + if x == nil { + t.Error("x for b is nil") + } else if b2 := x.(string); b2+"B" != "bB" { + t.Error("b2 (which should be b) plus B does not equal bB; value:", b2) + } + + x, found = tc.Get("c") + if !found { + t.Error("c was not found while getting c2") + } + if x == nil { + t.Error("x for c is nil") + } else if c2 := x.(float64); c2+1.2 != 4.7 { + t.Error("c2 (which should be 3.5) plus 1.2 does not equal 4.7; value:", c2) + } +} + +func TestCacheTimes(t *testing.T) { + var found bool + + tc := New(50*time.Millisecond, 1*time.Millisecond) + tc.Set("a", 1, DefaultExpiration) + tc.Set("b", 2, NoExpiration) + tc.Set("c", 3, 20*time.Millisecond) + tc.Set("d", 4, 70*time.Millisecond) + + <-time.After(25 * time.Millisecond) + _, found = tc.Get("c") + if found { + t.Error("Found c when it should have been automatically deleted") + } + + <-time.After(30 * time.Millisecond) + _, found = tc.Get("a") + if found { + t.Error("Found a when it should have been automatically deleted") + } + + _, found = tc.Get("b") + if !found { + t.Error("Did not find b even though it was set to never expire") + } + + _, found = tc.Get("d") + if !found { + t.Error("Did not find d even though it was set to expire later than the default") + } + + <-time.After(20 * time.Millisecond) + _, found = tc.Get("d") + if found { + t.Error("Found d when it should have been automatically deleted (later than the default)") + } +} + +func TestNewFrom(t *testing.T) { + m := map[string]Item{ + "a": Item{ + Object: 1, + Expiration: 0, + }, + "b": Item{ + Object: 2, + Expiration: 0, + }, + } + tc := NewFrom(DefaultExpiration, 0, m) + a, found := tc.Get("a") + if !found { + t.Fatal("Did not find a") + } + if a.(int) != 1 { + t.Fatal("a is not 1") + } + b, found := tc.Get("b") + if !found { + t.Fatal("Did not find b") + } + if b.(int) != 2 { + t.Fatal("b is not 2") + } +} + +func TestStorePointerToStruct(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("foo", &TestStruct{Num: 1}, DefaultExpiration) + x, found := tc.Get("foo") + if !found { + t.Fatal("*TestStruct was not found for foo") + 
} + foo := x.(*TestStruct) + foo.Num++ + + y, found := tc.Get("foo") + if !found { + t.Fatal("*TestStruct was not found for foo (second time)") + } + bar := y.(*TestStruct) + if bar.Num != 2 { + t.Fatal("TestStruct.Num is not 2") + } +} + +func TestIncrementWithInt(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tint", 1, DefaultExpiration) + err := tc.Increment("tint", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tint") + if !found { + t.Error("tint was not found") + } + if x.(int) != 3 { + t.Error("tint is not 3:", x) + } +} + +func TestIncrementWithInt8(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tint8", int8(1), DefaultExpiration) + err := tc.Increment("tint8", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tint8") + if !found { + t.Error("tint8 was not found") + } + if x.(int8) != 3 { + t.Error("tint8 is not 3:", x) + } +} + +func TestIncrementWithInt16(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tint16", int16(1), DefaultExpiration) + err := tc.Increment("tint16", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tint16") + if !found { + t.Error("tint16 was not found") + } + if x.(int16) != 3 { + t.Error("tint16 is not 3:", x) + } +} + +func TestIncrementWithInt32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tint32", int32(1), DefaultExpiration) + err := tc.Increment("tint32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tint32") + if !found { + t.Error("tint32 was not found") + } + if x.(int32) != 3 { + t.Error("tint32 is not 3:", x) + } +} + +func TestIncrementWithInt64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tint64", int64(1), DefaultExpiration) + err := tc.Increment("tint64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tint64") + if !found { + t.Error("tint64 was not found") + } + if x.(int64) != 3 { + t.Error("tint64 is not 3:", x) + } +} + +func TestIncrementWithUint(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuint", uint(1), DefaultExpiration) + err := tc.Increment("tuint", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tuint") + if !found { + t.Error("tuint was not found") + } + if x.(uint) != 3 { + t.Error("tuint is not 3:", x) + } +} + +func TestIncrementWithUintptr(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuintptr", uintptr(1), DefaultExpiration) + err := tc.Increment("tuintptr", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + + x, found := tc.Get("tuintptr") + if !found { + t.Error("tuintptr was not found") + } + if x.(uintptr) != 3 { + t.Error("tuintptr is not 3:", x) + } +} + +func TestIncrementWithUint8(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuint8", uint8(1), DefaultExpiration) + err := tc.Increment("tuint8", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tuint8") + if !found { + t.Error("tuint8 was not found") + } + if x.(uint8) != 3 { + t.Error("tuint8 is not 3:", x) + } +} + +func TestIncrementWithUint16(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuint16", uint16(1), DefaultExpiration) + err := tc.Increment("tuint16", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + + x, found := tc.Get("tuint16") + if !found { + t.Error("tuint16 was not found") + } + if x.(uint16) != 3 { + t.Error("tuint16 is not 3:", x) + } +} + 
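
The typed tests above pin down a subtlety of the vendored API worth calling out: the generic `Increment` mutates the stored value in place and returns only an error, while the typed variants (`IncrementInt8`, `IncrementUint64`, and so on) also return the updated value, and the stored value always keeps its original concrete type. A minimal standalone sketch (not part of the vendored files, using only the API added in this patch):

```go
package main

import (
	"fmt"

	cache "github.com/patrickmn/go-cache"
)

func main() {
	c := cache.New(cache.NoExpiration, 0) // no default expiry, no janitor

	// The generic Increment only reports an error; it does not return
	// the new value.
	c.Set("counter", int8(1), cache.DefaultExpiration)
	if err := c.Increment("counter", 2); err != nil {
		fmt.Println("increment failed:", err)
		return
	}

	// The stored value keeps its concrete type (int8), so the type
	// assertion must name int8 even though n was passed as an int64.
	if x, found := c.Get("counter"); found {
		fmt.Println(x.(int8)) // 3
	}

	// Typed increments return the updated value directly.
	if n, err := c.IncrementInt8("counter", 1); err == nil {
		fmt.Println(n) // 4
	}
}
```
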
+func TestIncrementWithUint32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuint32", uint32(1), DefaultExpiration) + err := tc.Increment("tuint32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("tuint32") + if !found { + t.Error("tuint32 was not found") + } + if x.(uint32) != 3 { + t.Error("tuint32 is not 3:", x) + } +} + +func TestIncrementWithUint64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuint64", uint64(1), DefaultExpiration) + err := tc.Increment("tuint64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + + x, found := tc.Get("tuint64") + if !found { + t.Error("tuint64 was not found") + } + if x.(uint64) != 3 { + t.Error("tuint64 is not 3:", x) + } +} + +func TestIncrementWithFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float32", float32(1.5), DefaultExpiration) + err := tc.Increment("float32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3.5 { + t.Error("float32 is not 3.5:", x) + } +} + +func TestIncrementWithFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float64", float64(1.5), DefaultExpiration) + err := tc.Increment("float64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3.5 { + t.Error("float64 is not 3.5:", x) + } +} + +func TestIncrementFloatWithFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float32", float32(1.5), DefaultExpiration) + err := tc.IncrementFloat("float32", 2) + if err != nil { + t.Error("Error incrementfloating:", err) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3.5 { + t.Error("float32 is not 3.5:", x) + } +} + +func TestIncrementFloatWithFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float64", float64(1.5), DefaultExpiration) + err := tc.IncrementFloat("float64", 2) + if err != nil { + t.Error("Error incrementfloating:", err) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3.5 { + t.Error("float64 is not 3.5:", x) + } +} + +func TestDecrementWithInt(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("int", int(5), DefaultExpiration) + err := tc.Decrement("int", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("int") + if !found { + t.Error("int was not found") + } + if x.(int) != 3 { + t.Error("int is not 3:", x) + } +} + +func TestDecrementWithInt8(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("int8", int8(5), DefaultExpiration) + err := tc.Decrement("int8", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("int8") + if !found { + t.Error("int8 was not found") + } + if x.(int8) != 3 { + t.Error("int8 is not 3:", x) + } +} + +func TestDecrementWithInt16(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("int16", int16(5), DefaultExpiration) + err := tc.Decrement("int16", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("int16") + if !found { + t.Error("int16 was not found") + } + if x.(int16) != 3 { + t.Error("int16 is not 3:", x) + } +} + +func TestDecrementWithInt32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("int32", int32(5), DefaultExpiration) + err := tc.Decrement("int32", 
2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("int32") + if !found { + t.Error("int32 was not found") + } + if x.(int32) != 3 { + t.Error("int32 is not 3:", x) + } +} + +func TestDecrementWithInt64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("int64", int64(5), DefaultExpiration) + err := tc.Decrement("int64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("int64") + if !found { + t.Error("int64 was not found") + } + if x.(int64) != 3 { + t.Error("int64 is not 3:", x) + } +} + +func TestDecrementWithUint(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint", uint(5), DefaultExpiration) + err := tc.Decrement("uint", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uint") + if !found { + t.Error("uint was not found") + } + if x.(uint) != 3 { + t.Error("uint is not 3:", x) + } +} + +func TestDecrementWithUintptr(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uintptr", uintptr(5), DefaultExpiration) + err := tc.Decrement("uintptr", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uintptr") + if !found { + t.Error("uintptr was not found") + } + if x.(uintptr) != 3 { + t.Error("uintptr is not 3:", x) + } +} + +func TestDecrementWithUint8(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint8", uint8(5), DefaultExpiration) + err := tc.Decrement("uint8", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uint8") + if !found { + t.Error("uint8 was not found") + } + if x.(uint8) != 3 { + t.Error("uint8 is not 3:", x) + } +} + +func TestDecrementWithUint16(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint16", uint16(5), DefaultExpiration) + err := tc.Decrement("uint16", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uint16") + if !found { + t.Error("uint16 was not found") + } + if x.(uint16) != 3 { + t.Error("uint16 is not 3:", x) + } +} + +func TestDecrementWithUint32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint32", uint32(5), DefaultExpiration) + err := tc.Decrement("uint32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uint32") + if !found { + t.Error("uint32 was not found") + } + if x.(uint32) != 3 { + t.Error("uint32 is not 3:", x) + } +} + +func TestDecrementWithUint64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint64", uint64(5), DefaultExpiration) + err := tc.Decrement("uint64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("uint64") + if !found { + t.Error("uint64 was not found") + } + if x.(uint64) != 3 { + t.Error("uint64 is not 3:", x) + } +} + +func TestDecrementWithFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float32", float32(5.5), DefaultExpiration) + err := tc.Decrement("float32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3.5 { + t.Error("float32 is not 3:", x) + } +} + +func TestDecrementWithFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float64", float64(5.5), DefaultExpiration) + err := tc.Decrement("float64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3.5 { + t.Error("float64 is not 3:", x) + } +} + 
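
The decrement tests above, together with the overflow and underflow tests further down in this file, document the arithmetic contract of the vendored cache: `Increment` and `Decrement` apply ordinary Go arithmetic to the stored value, so fixed-width integers wrap on overflow and underflow rather than saturating or returning an error (and, per the TODO in cache.go, `Decrement` is deliberately a separate code path rather than `Increment(k, -n)` for unsigned types). A minimal sketch, again outside the vendored patch:

```go
package main

import (
	"fmt"

	cache "github.com/patrickmn/go-cache"
)

func main() {
	c := cache.New(cache.NoExpiration, 0)

	// Signed overflow wraps, exactly as TestIncrementOverflowInt asserts.
	c.Set("int8", int8(127), cache.DefaultExpiration)
	_ = c.Increment("int8", 1)
	x, _ := c.Get("int8")
	fmt.Println(x.(int8)) // -128

	// Unsigned underflow wraps, as TestDecrementUnderflowUint asserts.
	c.Set("uint8", uint8(0), cache.DefaultExpiration)
	_ = c.Decrement("uint8", 1)
	y, _ := c.Get("uint8")
	fmt.Println(y.(uint8)) // 255
}
```

Callers that need saturating counters therefore have to guard against wrap themselves before calling `Increment` or `Decrement`.
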
+func TestDecrementFloatWithFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float32", float32(5.5), DefaultExpiration) + err := tc.DecrementFloat("float32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3.5 { + t.Error("float32 is not 3:", x) + } +} + +func TestDecrementFloatWithFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float64", float64(5.5), DefaultExpiration) + err := tc.DecrementFloat("float64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3.5 { + t.Error("float64 is not 3:", x) + } +} + +func TestIncrementInt(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tint", 1, DefaultExpiration) + n, err := tc.IncrementInt("tint", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tint") + if !found { + t.Error("tint was not found") + } + if x.(int) != 3 { + t.Error("tint is not 3:", x) + } +} + +func TestIncrementInt8(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tint8", int8(1), DefaultExpiration) + n, err := tc.IncrementInt8("tint8", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tint8") + if !found { + t.Error("tint8 was not found") + } + if x.(int8) != 3 { + t.Error("tint8 is not 3:", x) + } +} + +func TestIncrementInt16(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tint16", int16(1), DefaultExpiration) + n, err := tc.IncrementInt16("tint16", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tint16") + if !found { + t.Error("tint16 was not found") + } + if x.(int16) != 3 { + t.Error("tint16 is not 3:", x) + } +} + +func TestIncrementInt32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tint32", int32(1), DefaultExpiration) + n, err := tc.IncrementInt32("tint32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tint32") + if !found { + t.Error("tint32 was not found") + } + if x.(int32) != 3 { + t.Error("tint32 is not 3:", x) + } +} + +func TestIncrementInt64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tint64", int64(1), DefaultExpiration) + n, err := tc.IncrementInt64("tint64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tint64") + if !found { + t.Error("tint64 was not found") + } + if x.(int64) != 3 { + t.Error("tint64 is not 3:", x) + } +} + +func TestIncrementUint(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuint", uint(1), DefaultExpiration) + n, err := tc.IncrementUint("tuint", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuint") + if !found { + t.Error("tuint was not found") + } + if x.(uint) != 3 { + t.Error("tuint is not 3:", x) + } +} + +func TestIncrementUintptr(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuintptr", uintptr(1), DefaultExpiration) + n, err := tc.IncrementUintptr("tuintptr", 2) + if err != nil { + 
t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuintptr") + if !found { + t.Error("tuintptr was not found") + } + if x.(uintptr) != 3 { + t.Error("tuintptr is not 3:", x) + } +} + +func TestIncrementUint8(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuint8", uint8(1), DefaultExpiration) + n, err := tc.IncrementUint8("tuint8", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuint8") + if !found { + t.Error("tuint8 was not found") + } + if x.(uint8) != 3 { + t.Error("tuint8 is not 3:", x) + } +} + +func TestIncrementUint16(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuint16", uint16(1), DefaultExpiration) + n, err := tc.IncrementUint16("tuint16", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuint16") + if !found { + t.Error("tuint16 was not found") + } + if x.(uint16) != 3 { + t.Error("tuint16 is not 3:", x) + } +} + +func TestIncrementUint32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuint32", uint32(1), DefaultExpiration) + n, err := tc.IncrementUint32("tuint32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuint32") + if !found { + t.Error("tuint32 was not found") + } + if x.(uint32) != 3 { + t.Error("tuint32 is not 3:", x) + } +} + +func TestIncrementUint64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("tuint64", uint64(1), DefaultExpiration) + n, err := tc.IncrementUint64("tuint64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("tuint64") + if !found { + t.Error("tuint64 was not found") + } + if x.(uint64) != 3 { + t.Error("tuint64 is not 3:", x) + } +} + +func TestIncrementFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float32", float32(1.5), DefaultExpiration) + n, err := tc.IncrementFloat32("float32", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3.5 { + t.Error("Returned number is not 3.5:", n) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3.5 { + t.Error("float32 is not 3.5:", x) + } +} + +func TestIncrementFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float64", float64(1.5), DefaultExpiration) + n, err := tc.IncrementFloat64("float64", 2) + if err != nil { + t.Error("Error incrementing:", err) + } + if n != 3.5 { + t.Error("Returned number is not 3.5:", n) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3.5 { + t.Error("float64 is not 3.5:", x) + } +} + +func TestDecrementInt8(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("int8", int8(5), DefaultExpiration) + n, err := tc.DecrementInt8("int8", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("int8") + if !found { + t.Error("int8 was not found") + } + if x.(int8) != 3 { + t.Error("int8 is not 3:", x) + } +} + +func TestDecrementInt16(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("int16", int16(5), DefaultExpiration) + n, err := tc.DecrementInt16("int16", 2) + if err != nil { + t.Error("Error 
decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("int16") + if !found { + t.Error("int16 was not found") + } + if x.(int16) != 3 { + t.Error("int16 is not 3:", x) + } +} + +func TestDecrementInt32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("int32", int32(5), DefaultExpiration) + n, err := tc.DecrementInt32("int32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("int32") + if !found { + t.Error("int32 was not found") + } + if x.(int32) != 3 { + t.Error("int32 is not 3:", x) + } +} + +func TestDecrementInt64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("int64", int64(5), DefaultExpiration) + n, err := tc.DecrementInt64("int64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("int64") + if !found { + t.Error("int64 was not found") + } + if x.(int64) != 3 { + t.Error("int64 is not 3:", x) + } +} + +func TestDecrementUint(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint", uint(5), DefaultExpiration) + n, err := tc.DecrementUint("uint", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("uint") + if !found { + t.Error("uint was not found") + } + if x.(uint) != 3 { + t.Error("uint is not 3:", x) + } +} + +func TestDecrementUintptr(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uintptr", uintptr(5), DefaultExpiration) + n, err := tc.DecrementUintptr("uintptr", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("uintptr") + if !found { + t.Error("uintptr was not found") + } + if x.(uintptr) != 3 { + t.Error("uintptr is not 3:", x) + } +} + +func TestDecrementUint8(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint8", uint8(5), DefaultExpiration) + n, err := tc.DecrementUint8("uint8", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("uint8") + if !found { + t.Error("uint8 was not found") + } + if x.(uint8) != 3 { + t.Error("uint8 is not 3:", x) + } +} + +func TestDecrementUint16(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint16", uint16(5), DefaultExpiration) + n, err := tc.DecrementUint16("uint16", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("uint16") + if !found { + t.Error("uint16 was not found") + } + if x.(uint16) != 3 { + t.Error("uint16 is not 3:", x) + } +} + +func TestDecrementUint32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint32", uint32(5), DefaultExpiration) + n, err := tc.DecrementUint32("uint32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("uint32") + if !found { + t.Error("uint32 was not found") + } + if x.(uint32) != 3 { + t.Error("uint32 is not 3:", x) + } +} + +func TestDecrementUint64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint64", uint64(5), DefaultExpiration) + n, err := tc.DecrementUint64("uint64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := 
tc.Get("uint64") + if !found { + t.Error("uint64 was not found") + } + if x.(uint64) != 3 { + t.Error("uint64 is not 3:", x) + } +} + +func TestDecrementFloat32(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float32", float32(5), DefaultExpiration) + n, err := tc.DecrementFloat32("float32", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("float32") + if !found { + t.Error("float32 was not found") + } + if x.(float32) != 3 { + t.Error("float32 is not 3:", x) + } +} + +func TestDecrementFloat64(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("float64", float64(5), DefaultExpiration) + n, err := tc.DecrementFloat64("float64", 2) + if err != nil { + t.Error("Error decrementing:", err) + } + if n != 3 { + t.Error("Returned number is not 3:", n) + } + x, found := tc.Get("float64") + if !found { + t.Error("float64 was not found") + } + if x.(float64) != 3 { + t.Error("float64 is not 3:", x) + } +} + +func TestAdd(t *testing.T) { + tc := New(DefaultExpiration, 0) + err := tc.Add("foo", "bar", DefaultExpiration) + if err != nil { + t.Error("Couldn't add foo even though it shouldn't exist") + } + err = tc.Add("foo", "baz", DefaultExpiration) + if err == nil { + t.Error("Successfully added another foo when it should have returned an error") + } +} + +func TestReplace(t *testing.T) { + tc := New(DefaultExpiration, 0) + err := tc.Replace("foo", "bar", DefaultExpiration) + if err == nil { + t.Error("Replaced foo when it shouldn't exist") + } + tc.Set("foo", "bar", DefaultExpiration) + err = tc.Replace("foo", "bar", DefaultExpiration) + if err != nil { + t.Error("Couldn't replace existing key foo") + } +} + +func TestDelete(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("foo", "bar", DefaultExpiration) + tc.Delete("foo") + x, found := tc.Get("foo") + if found { + t.Error("foo was found, but it should have been deleted") + } + if x != nil { + t.Error("x is not nil:", x) + } +} + +func TestItemCount(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("foo", "1", DefaultExpiration) + tc.Set("bar", "2", DefaultExpiration) + tc.Set("baz", "3", DefaultExpiration) + if n := tc.ItemCount(); n != 3 { + t.Errorf("Item count is not 3: %d", n) + } +} + +func TestFlush(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("foo", "bar", DefaultExpiration) + tc.Set("baz", "yes", DefaultExpiration) + tc.Flush() + x, found := tc.Get("foo") + if found { + t.Error("foo was found, but it should have been deleted") + } + if x != nil { + t.Error("x is not nil:", x) + } + x, found = tc.Get("baz") + if found { + t.Error("baz was found, but it should have been deleted") + } + if x != nil { + t.Error("x is not nil:", x) + } +} + +func TestIncrementOverflowInt(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("int8", int8(127), DefaultExpiration) + err := tc.Increment("int8", 1) + if err != nil { + t.Error("Error incrementing int8:", err) + } + x, _ := tc.Get("int8") + int8 := x.(int8) + if int8 != -128 { + t.Error("int8 did not overflow as expected; value:", int8) + } + +} + +func TestIncrementOverflowUint(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint8", uint8(255), DefaultExpiration) + err := tc.Increment("uint8", 1) + if err != nil { + t.Error("Error incrementing int8:", err) + } + x, _ := tc.Get("uint8") + uint8 := x.(uint8) + if uint8 != 0 { + t.Error("uint8 did not overflow as expected; value:", uint8) + } +} + +func TestDecrementUnderflowUint(t 
*testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("uint8", uint8(0), DefaultExpiration) + err := tc.Decrement("uint8", 1) + if err != nil { + t.Error("Error decrementing int8:", err) + } + x, _ := tc.Get("uint8") + uint8 := x.(uint8) + if uint8 != 255 { + t.Error("uint8 did not underflow as expected; value:", uint8) + } +} + +func TestOnEvicted(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Set("foo", 3, DefaultExpiration) + if tc.onEvicted != nil { + t.Fatal("tc.onEvicted is not nil") + } + works := false + tc.OnEvicted(func(k string, v interface{}) { + if k == "foo" && v.(int) == 3 { + works = true + } + tc.Set("bar", 4, DefaultExpiration) + }) + tc.Delete("foo") + x, _ := tc.Get("bar") + if !works { + t.Error("works bool not true") + } + if x.(int) != 4 { + t.Error("bar was not 4") + } +} + +func TestCacheSerialization(t *testing.T) { + tc := New(DefaultExpiration, 0) + testFillAndSerialize(t, tc) + + // Check if gob.Register behaves properly even after multiple gob.Register + // on c.Items (many of which will be the same type) + testFillAndSerialize(t, tc) +} + +func testFillAndSerialize(t *testing.T, tc *Cache) { + tc.Set("a", "a", DefaultExpiration) + tc.Set("b", "b", DefaultExpiration) + tc.Set("c", "c", DefaultExpiration) + tc.Set("expired", "foo", 1*time.Millisecond) + tc.Set("*struct", &TestStruct{Num: 1}, DefaultExpiration) + tc.Set("[]struct", []TestStruct{ + {Num: 2}, + {Num: 3}, + }, DefaultExpiration) + tc.Set("[]*struct", []*TestStruct{ + &TestStruct{Num: 4}, + &TestStruct{Num: 5}, + }, DefaultExpiration) + tc.Set("structception", &TestStruct{ + Num: 42, + Children: []*TestStruct{ + &TestStruct{Num: 6174}, + &TestStruct{Num: 4716}, + }, + }, DefaultExpiration) + + fp := &bytes.Buffer{} + err := tc.Save(fp) + if err != nil { + t.Fatal("Couldn't save cache to fp:", err) + } + + oc := New(DefaultExpiration, 0) + err = oc.Load(fp) + if err != nil { + t.Fatal("Couldn't load cache from fp:", err) + } + + a, found := oc.Get("a") + if !found { + t.Error("a was not found") + } + if a.(string) != "a" { + t.Error("a is not a") + } + + b, found := oc.Get("b") + if !found { + t.Error("b was not found") + } + if b.(string) != "b" { + t.Error("b is not b") + } + + c, found := oc.Get("c") + if !found { + t.Error("c was not found") + } + if c.(string) != "c" { + t.Error("c is not c") + } + + <-time.After(5 * time.Millisecond) + _, found = oc.Get("expired") + if found { + t.Error("expired was found") + } + + s1, found := oc.Get("*struct") + if !found { + t.Error("*struct was not found") + } + if s1.(*TestStruct).Num != 1 { + t.Error("*struct.Num is not 1") + } + + s2, found := oc.Get("[]struct") + if !found { + t.Error("[]struct was not found") + } + s2r := s2.([]TestStruct) + if len(s2r) != 2 { + t.Error("Length of s2r is not 2") + } + if s2r[0].Num != 2 { + t.Error("s2r[0].Num is not 2") + } + if s2r[1].Num != 3 { + t.Error("s2r[1].Num is not 3") + } + + s3, found := oc.get("[]*struct") + if !found { + t.Error("[]*struct was not found") + } + s3r := s3.([]*TestStruct) + if len(s3r) != 2 { + t.Error("Length of s3r is not 2") + } + if s3r[0].Num != 4 { + t.Error("s3r[0].Num is not 4") + } + if s3r[1].Num != 5 { + t.Error("s3r[1].Num is not 5") + } + + s4, found := oc.get("structception") + if !found { + t.Error("structception was not found") + } + s4r := s4.(*TestStruct) + if len(s4r.Children) != 2 { + t.Error("Length of s4r.Children is not 2") + } + if s4r.Children[0].Num != 6174 { + t.Error("s4r.Children[0].Num is not 6174") + } + if s4r.Children[1].Num != 4716 { + 
t.Error("s4r.Children[1].Num is not 4716") + } +} + +func TestFileSerialization(t *testing.T) { + tc := New(DefaultExpiration, 0) + tc.Add("a", "a", DefaultExpiration) + tc.Add("b", "b", DefaultExpiration) + f, err := ioutil.TempFile("", "go-cache-cache.dat") + if err != nil { + t.Fatal("Couldn't create cache file:", err) + } + fname := f.Name() + f.Close() + tc.SaveFile(fname) + + oc := New(DefaultExpiration, 0) + oc.Add("a", "aa", 0) // this should not be overwritten + err = oc.LoadFile(fname) + if err != nil { + t.Error(err) + } + a, found := oc.Get("a") + if !found { + t.Error("a was not found") + } + astr := a.(string) + if astr != "aa" { + if astr == "a" { + t.Error("a was overwritten") + } else { + t.Error("a is not aa") + } + } + b, found := oc.Get("b") + if !found { + t.Error("b was not found") + } + if b.(string) != "b" { + t.Error("b is not b") + } +} + +func TestSerializeUnserializable(t *testing.T) { + tc := New(DefaultExpiration, 0) + ch := make(chan bool, 1) + ch <- true + tc.Set("chan", ch, DefaultExpiration) + fp := &bytes.Buffer{} + err := tc.Save(fp) // this should fail gracefully + if err.Error() != "gob NewTypeObject can't handle type: chan bool" { + t.Error("Error from Save was not gob NewTypeObject can't handle type chan bool:", err) + } +} + +func BenchmarkCacheGetExpiring(b *testing.B) { + benchmarkCacheGet(b, 5*time.Minute) +} + +func BenchmarkCacheGetNotExpiring(b *testing.B) { + benchmarkCacheGet(b, NoExpiration) +} + +func benchmarkCacheGet(b *testing.B, exp time.Duration) { + b.StopTimer() + tc := New(exp, 0) + tc.Set("foo", "bar", DefaultExpiration) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.Get("foo") + } +} + +func BenchmarkRWMutexMapGet(b *testing.B) { + b.StopTimer() + m := map[string]string{ + "foo": "bar", + } + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.RLock() + _, _ = m["foo"] + mu.RUnlock() + } +} + +func BenchmarkRWMutexInterfaceMapGetStruct(b *testing.B) { + b.StopTimer() + s := struct{ name string }{name: "foo"} + m := map[interface{}]string{ + s: "bar", + } + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.RLock() + _, _ = m[s] + mu.RUnlock() + } +} + +func BenchmarkRWMutexInterfaceMapGetString(b *testing.B) { + b.StopTimer() + m := map[interface{}]string{ + "foo": "bar", + } + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.RLock() + _, _ = m["foo"] + mu.RUnlock() + } +} + +func BenchmarkCacheGetConcurrentExpiring(b *testing.B) { + benchmarkCacheGetConcurrent(b, 5*time.Minute) +} + +func BenchmarkCacheGetConcurrentNotExpiring(b *testing.B) { + benchmarkCacheGetConcurrent(b, NoExpiration) +} + +func benchmarkCacheGetConcurrent(b *testing.B, exp time.Duration) { + b.StopTimer() + tc := New(exp, 0) + tc.Set("foo", "bar", DefaultExpiration) + wg := new(sync.WaitGroup) + workers := runtime.NumCPU() + each := b.N / workers + wg.Add(workers) + b.StartTimer() + for i := 0; i < workers; i++ { + go func() { + for j := 0; j < each; j++ { + tc.Get("foo") + } + wg.Done() + }() + } + wg.Wait() +} + +func BenchmarkRWMutexMapGetConcurrent(b *testing.B) { + b.StopTimer() + m := map[string]string{ + "foo": "bar", + } + mu := sync.RWMutex{} + wg := new(sync.WaitGroup) + workers := runtime.NumCPU() + each := b.N / workers + wg.Add(workers) + b.StartTimer() + for i := 0; i < workers; i++ { + go func() { + for j := 0; j < each; j++ { + mu.RLock() + _, _ = m["foo"] + mu.RUnlock() + } + wg.Done() + }() + } + wg.Wait() +} + +func BenchmarkCacheGetManyConcurrentExpiring(b *testing.B) { 
+ benchmarkCacheGetManyConcurrent(b, 5*time.Minute) +} + +func BenchmarkCacheGetManyConcurrentNotExpiring(b *testing.B) { + benchmarkCacheGetManyConcurrent(b, NoExpiration) +} + +func benchmarkCacheGetManyConcurrent(b *testing.B, exp time.Duration) { + // This is the same as BenchmarkCacheGetConcurrent, but its result + // can be compared against BenchmarkShardedCacheGetManyConcurrent + // in sharded_test.go. + b.StopTimer() + n := 10000 + tc := New(exp, 0) + keys := make([]string, n) + for i := 0; i < n; i++ { + k := "foo" + strconv.Itoa(i) + keys[i] = k + tc.Set(k, "bar", DefaultExpiration) + } + each := b.N / n + wg := new(sync.WaitGroup) + wg.Add(n) + for _, v := range keys { + go func(k string) { + for j := 0; j < each; j++ { + tc.Get(k) + } + wg.Done() + }(v) + } + b.StartTimer() + wg.Wait() +} + +func BenchmarkCacheSetExpiring(b *testing.B) { + benchmarkCacheSet(b, 5*time.Minute) +} + +func BenchmarkCacheSetNotExpiring(b *testing.B) { + benchmarkCacheSet(b, NoExpiration) +} + +func benchmarkCacheSet(b *testing.B, exp time.Duration) { + b.StopTimer() + tc := New(exp, 0) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.Set("foo", "bar", DefaultExpiration) + } +} + +func BenchmarkRWMutexMapSet(b *testing.B) { + b.StopTimer() + m := map[string]string{} + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.Lock() + m["foo"] = "bar" + mu.Unlock() + } +} + +func BenchmarkCacheSetDelete(b *testing.B) { + b.StopTimer() + tc := New(DefaultExpiration, 0) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.Set("foo", "bar", DefaultExpiration) + tc.Delete("foo") + } +} + +func BenchmarkRWMutexMapSetDelete(b *testing.B) { + b.StopTimer() + m := map[string]string{} + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.Lock() + m["foo"] = "bar" + mu.Unlock() + mu.Lock() + delete(m, "foo") + mu.Unlock() + } +} + +func BenchmarkCacheSetDeleteSingleLock(b *testing.B) { + b.StopTimer() + tc := New(DefaultExpiration, 0) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.mu.Lock() + tc.set("foo", "bar", DefaultExpiration) + tc.delete("foo") + tc.mu.Unlock() + } +} + +func BenchmarkRWMutexMapSetDeleteSingleLock(b *testing.B) { + b.StopTimer() + m := map[string]string{} + mu := sync.RWMutex{} + b.StartTimer() + for i := 0; i < b.N; i++ { + mu.Lock() + m["foo"] = "bar" + delete(m, "foo") + mu.Unlock() + } +} + +func BenchmarkIncrementInt(b *testing.B) { + b.StopTimer() + tc := New(DefaultExpiration, 0) + tc.Set("foo", 0, DefaultExpiration) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.IncrementInt("foo", 1) + } +} + +func BenchmarkDeleteExpiredLoop(b *testing.B) { + b.StopTimer() + tc := New(5*time.Minute, 0) + tc.mu.Lock() + for i := 0; i < 100000; i++ { + tc.set(strconv.Itoa(i), "bar", DefaultExpiration) + } + tc.mu.Unlock() + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.DeleteExpired() + } +} + +func TestGetWithExpiration(t *testing.T) { + tc := New(DefaultExpiration, 0) + + a, expiration, found := tc.GetWithExpiration("a") + if found || a != nil || !expiration.IsZero() { + t.Error("Getting A found value that shouldn't exist:", a) + } + + b, expiration, found := tc.GetWithExpiration("b") + if found || b != nil || !expiration.IsZero() { + t.Error("Getting B found value that shouldn't exist:", b) + } + + c, expiration, found := tc.GetWithExpiration("c") + if found || c != nil || !expiration.IsZero() { + t.Error("Getting C found value that shouldn't exist:", c) + } + + tc.Set("a", 1, DefaultExpiration) + tc.Set("b", "b", DefaultExpiration) + tc.Set("c", 3.5, 
DefaultExpiration) + tc.Set("d", 1, NoExpiration) + tc.Set("e", 1, 50*time.Millisecond) + + x, expiration, found := tc.GetWithExpiration("a") + if !found { + t.Error("a was not found while getting a2") + } + if x == nil { + t.Error("x for a is nil") + } else if a2 := x.(int); a2+2 != 3 { + t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2) + } + if !expiration.IsZero() { + t.Error("expiration for a is not a zeroed time") + } + + x, expiration, found = tc.GetWithExpiration("b") + if !found { + t.Error("b was not found while getting b2") + } + if x == nil { + t.Error("x for b is nil") + } else if b2 := x.(string); b2+"B" != "bB" { + t.Error("b2 (which should be b) plus B does not equal bB; value:", b2) + } + if !expiration.IsZero() { + t.Error("expiration for b is not a zeroed time") + } + + x, expiration, found = tc.GetWithExpiration("c") + if !found { + t.Error("c was not found while getting c2") + } + if x == nil { + t.Error("x for c is nil") + } else if c2 := x.(float64); c2+1.2 != 4.7 { + t.Error("c2 (which should be 3.5) plus 1.2 does not equal 4.7; value:", c2) + } + if !expiration.IsZero() { + t.Error("expiration for c is not a zeroed time") + } + + x, expiration, found = tc.GetWithExpiration("d") + if !found { + t.Error("d was not found while getting d2") + } + if x == nil { + t.Error("x for d is nil") + } else if d2 := x.(int); d2+2 != 3 { + t.Error("d (which should be 1) plus 2 does not equal 3; value:", d2) + } + if !expiration.IsZero() { + t.Error("expiration for d is not a zeroed time") + } + + x, expiration, found = tc.GetWithExpiration("e") + if !found { + t.Error("e was not found while getting e2") + } + if x == nil { + t.Error("x for e is nil") + } else if e2 := x.(int); e2+2 != 3 { + t.Error("e (which should be 1) plus 2 does not equal 3; value:", e2) + } + if expiration.UnixNano() != tc.items["e"].Expiration { + t.Error("expiration for e is not the correct time") + } + if expiration.UnixNano() < time.Now().UnixNano() { + t.Error("expiration for e is in the past") + } +} diff --git a/vendor/github.com/patrickmn/go-cache/sharded.go b/vendor/github.com/patrickmn/go-cache/sharded.go new file mode 100644 index 00000000..bcc0538b --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/sharded.go @@ -0,0 +1,192 @@ +package cache + +import ( + "crypto/rand" + "math" + "math/big" + insecurerand "math/rand" + "os" + "runtime" + "time" +) + +// This is an experimental and unexported (for now) attempt at making a cache +// with better algorithmic complexity than the standard one, namely by +// preventing write locks of the entire cache when an item is added. As of the +// time of writing, the overhead of selecting buckets results in cache +// operations being about twice as slow as for the standard cache with small +// total cache sizes, and faster for larger ones. +// +// See cache_test.go for a few benchmarks. + +type unexportedShardedCache struct { + *shardedCache +} + +type shardedCache struct { + seed uint32 + m uint32 + cs []*cache + janitor *shardedJanitor +} + +// djb2 with better shuffling. 5x faster than FNV with the hash.Hash overhead. +func djb33(seed uint32, k string) uint32 { + var ( + l = uint32(len(k)) + d = 5381 + seed + l + i = uint32(0) + ) + // Why is all this 5x faster than a for loop? 
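+	// (Likely because the body is manually unrolled: each pass mixes four
+	// bytes, amortizing the loop-condition check and branch overhead.
+	// Note the tail switch below: case 1 has an empty body, so a single
+	// trailing byte is never mixed in, and case 4 mixes only three of the
+	// four remaining bytes. For shard selection this merely weakens the
+	// distribution slightly; the hash stays deterministic per seed and key.)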
+ if l >= 4 { + for i < l-4 { + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + d = (d * 33) ^ uint32(k[i+2]) + d = (d * 33) ^ uint32(k[i+3]) + i += 4 + } + } + switch l - i { + case 1: + case 2: + d = (d * 33) ^ uint32(k[i]) + case 3: + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + case 4: + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + d = (d * 33) ^ uint32(k[i+2]) + } + return d ^ (d >> 16) +} + +func (sc *shardedCache) bucket(k string) *cache { + return sc.cs[djb33(sc.seed, k)%sc.m] +} + +func (sc *shardedCache) Set(k string, x interface{}, d time.Duration) { + sc.bucket(k).Set(k, x, d) +} + +func (sc *shardedCache) Add(k string, x interface{}, d time.Duration) error { + return sc.bucket(k).Add(k, x, d) +} + +func (sc *shardedCache) Replace(k string, x interface{}, d time.Duration) error { + return sc.bucket(k).Replace(k, x, d) +} + +func (sc *shardedCache) Get(k string) (interface{}, bool) { + return sc.bucket(k).Get(k) +} + +func (sc *shardedCache) Increment(k string, n int64) error { + return sc.bucket(k).Increment(k, n) +} + +func (sc *shardedCache) IncrementFloat(k string, n float64) error { + return sc.bucket(k).IncrementFloat(k, n) +} + +func (sc *shardedCache) Decrement(k string, n int64) error { + return sc.bucket(k).Decrement(k, n) +} + +func (sc *shardedCache) Delete(k string) { + sc.bucket(k).Delete(k) +} + +func (sc *shardedCache) DeleteExpired() { + for _, v := range sc.cs { + v.DeleteExpired() + } +} + +// Returns the items in the cache. This may include items that have expired, +// but have not yet been cleaned up. If this is significant, the Expiration +// fields of the items should be checked. Note that explicit synchronization +// is needed to use a cache and its corresponding Items() return values at +// the same time, as the maps are shared. +func (sc *shardedCache) Items() []map[string]Item { + res := make([]map[string]Item, len(sc.cs)) + for i, v := range sc.cs { + res[i] = v.Items() + } + return res +} + +func (sc *shardedCache) Flush() { + for _, v := range sc.cs { + v.Flush() + } +} + +type shardedJanitor struct { + Interval time.Duration + stop chan bool +} + +func (j *shardedJanitor) Run(sc *shardedCache) { + j.stop = make(chan bool) + tick := time.Tick(j.Interval) + for { + select { + case <-tick: + sc.DeleteExpired() + case <-j.stop: + return + } + } +} + +func stopShardedJanitor(sc *unexportedShardedCache) { + sc.janitor.stop <- true +} + +func runShardedJanitor(sc *shardedCache, ci time.Duration) { + j := &shardedJanitor{ + Interval: ci, + } + sc.janitor = j + go j.Run(sc) +} + +func newShardedCache(n int, de time.Duration) *shardedCache { + max := big.NewInt(0).SetUint64(uint64(math.MaxUint32)) + rnd, err := rand.Int(rand.Reader, max) + var seed uint32 + if err != nil { + os.Stderr.Write([]byte("WARNING: go-cache's newShardedCache failed to read from the system CSPRNG (/dev/urandom or equivalent.) Your system's security may be compromised. 
Continuing with an insecure seed.\n")) + seed = insecurerand.Uint32() + } else { + seed = uint32(rnd.Uint64()) + } + sc := &shardedCache{ + seed: seed, + m: uint32(n), + cs: make([]*cache, n), + } + for i := 0; i < n; i++ { + c := &cache{ + defaultExpiration: de, + items: map[string]Item{}, + } + sc.cs[i] = c + } + return sc +} + +func unexportedNewSharded(defaultExpiration, cleanupInterval time.Duration, shards int) *unexportedShardedCache { + if defaultExpiration == 0 { + defaultExpiration = -1 + } + sc := newShardedCache(shards, defaultExpiration) + SC := &unexportedShardedCache{sc} + if cleanupInterval > 0 { + runShardedJanitor(sc, cleanupInterval) + runtime.SetFinalizer(SC, stopShardedJanitor) + } + return SC +} diff --git a/vendor/github.com/patrickmn/go-cache/sharded_test.go b/vendor/github.com/patrickmn/go-cache/sharded_test.go new file mode 100644 index 00000000..84220add --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/sharded_test.go @@ -0,0 +1,85 @@ +package cache + +import ( + "strconv" + "sync" + "testing" + "time" +) + +// func TestDjb33(t *testing.T) { +// } + +var shardedKeys = []string{ + "f", + "fo", + "foo", + "barf", + "barfo", + "foobar", + "bazbarf", + "bazbarfo", + "bazbarfoo", + "foobarbazq", + "foobarbazqu", + "foobarbazquu", + "foobarbazquux", +} + +func TestShardedCache(t *testing.T) { + tc := unexportedNewSharded(DefaultExpiration, 0, 13) + for _, v := range shardedKeys { + tc.Set(v, "value", DefaultExpiration) + } +} + +func BenchmarkShardedCacheGetExpiring(b *testing.B) { + benchmarkShardedCacheGet(b, 5*time.Minute) +} + +func BenchmarkShardedCacheGetNotExpiring(b *testing.B) { + benchmarkShardedCacheGet(b, NoExpiration) +} + +func benchmarkShardedCacheGet(b *testing.B, exp time.Duration) { + b.StopTimer() + tc := unexportedNewSharded(exp, 0, 10) + tc.Set("foobarba", "zquux", DefaultExpiration) + b.StartTimer() + for i := 0; i < b.N; i++ { + tc.Get("foobarba") + } +} + +func BenchmarkShardedCacheGetManyConcurrentExpiring(b *testing.B) { + benchmarkShardedCacheGetManyConcurrent(b, 5*time.Minute) +} + +func BenchmarkShardedCacheGetManyConcurrentNotExpiring(b *testing.B) { + benchmarkShardedCacheGetManyConcurrent(b, NoExpiration) +} + +func benchmarkShardedCacheGetManyConcurrent(b *testing.B, exp time.Duration) { + b.StopTimer() + n := 10000 + tsc := unexportedNewSharded(exp, 0, 20) + keys := make([]string, n) + for i := 0; i < n; i++ { + k := "foo" + strconv.Itoa(i) + keys[i] = k + tsc.Set(k, "bar", DefaultExpiration) + } + each := b.N / n + wg := new(sync.WaitGroup) + wg.Add(n) + for _, v := range keys { + go func(k string) { + for j := 0; j < each; j++ { + tsc.Get(k) + } + wg.Done() + }(v) + } + b.StartTimer() + wg.Wait() +} diff --git a/vendor/github.com/robfig/cron/v3/.gitignore b/vendor/github.com/robfig/cron/v3/.gitignore new file mode 100644 index 00000000..00268614 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/.gitignore @@ -0,0 +1,22 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/vendor/github.com/robfig/cron/v3/.travis.yml b/vendor/github.com/robfig/cron/v3/.travis.yml new file mode 100644 index 00000000..4f2ee4d9 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/.travis.yml @@ -0,0 +1 @@ +language: go diff --git a/vendor/github.com/robfig/cron/v3/LICENSE 
b/vendor/github.com/robfig/cron/v3/LICENSE
new file mode 100644
index 00000000..3a0f627f
--- /dev/null
+++ b/vendor/github.com/robfig/cron/v3/LICENSE
@@ -0,0 +1,21 @@
+Copyright (C) 2012 Rob Figueiredo
+All Rights Reserved.
+
+MIT LICENSE
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/vendor/github.com/robfig/cron/v3/README.md b/vendor/github.com/robfig/cron/v3/README.md
new file mode 100644
index 00000000..38c4d8a0
--- /dev/null
+++ b/vendor/github.com/robfig/cron/v3/README.md
@@ -0,0 +1,125 @@
+[![GoDoc](http://godoc.org/github.com/robfig/cron?status.png)](http://godoc.org/github.com/robfig/cron)
+[![Build Status](https://travis-ci.org/robfig/cron.svg?branch=master)](https://travis-ci.org/robfig/cron)
+
+# cron
+
+Cron V3 has been released!
+
+To download the specific tagged release, run:
+```bash
+go get github.com/robfig/cron/v3@v3.0.0
+```
+Import it in your program as:
+```go
+import "github.com/robfig/cron/v3"
+```
+It requires Go 1.11 or later due to usage of Go Modules.
+
+Refer to the documentation here:
+http://godoc.org/github.com/robfig/cron
+
+The rest of this document describes the advances in v3 and lists the
+breaking changes for users that wish to upgrade from an earlier version.
+
+## Upgrading to v3 (June 2019)
+
+cron v3 is a major upgrade to the library that addresses all outstanding bugs,
+feature requests, and rough edges. It is based on a merge of master, which
+contains various fixes to issues found over the years, and the v2 branch, which
+contains some backwards-incompatible features like the ability to remove cron
+jobs. In addition, v3 adds support for Go Modules, cleans up rough edges like
+the timezone support, and fixes a number of bugs.
+
+New features:
+
+- Support for Go modules. Callers must now import this library as
+  `github.com/robfig/cron/v3`, instead of `gopkg.in/...`
+
+- Fixed bugs:
+  - 0f01e6b parser: fix combining of Dow and Dom (#70)
+  - dbf3220 adjust times when rolling the clock forward to handle non-existent midnight (#157)
+  - eeecf15 spec_test.go: ensure an error is returned on 0 increment (#144)
+  - 70971dc cron.Entries(): update request for snapshot to include a reply channel (#97)
+  - 1cba5e6 cron: fix: removing a job causes the next scheduled job to run too late (#206)
+
+- Standard cron spec parsing by default (first field is "minute"), with an easy
+  way to opt into the seconds field (quartz-compatible). Note, however, that the
+  year field (optional in Quartz) is not supported.
+ +- Extensible, key/value logging via an interface that complies with + the https://github.com/go-logr/logr project. + +- The new Chain & JobWrapper types allow you to install "interceptors" to add + cross-cutting behavior like the following: + - Recover any panics from jobs + - Delay a job's execution if the previous run hasn't completed yet + - Skip a job's execution if the previous run hasn't completed yet + - Log each job's invocations + - Notification when jobs are completed + +It is backwards incompatible with both v1 and v2. These updates are required: + +- The v1 branch accepted an optional seconds field at the beginning of the cron + spec. This is non-standard and has led to a lot of confusion. The new default + parser conforms to the standard as described by [the Cron wikipedia page]. + + UPDATING: To retain the old behavior, construct your Cron with a custom + parser: +```go +// Seconds field, required +cron.New(cron.WithSeconds()) + +// Seconds field, optional +cron.New(cron.WithParser(cron.NewParser( + cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor, +))) +``` +- The Cron type now accepts functional options on construction rather than the + previous ad-hoc behavior modification mechanisms (setting a field, calling a setter). + + UPDATING: Code that sets Cron.ErrorLogger or calls Cron.SetLocation must be + updated to provide those values on construction. + +- CRON_TZ is now the recommended way to specify the timezone of a single + schedule, which is sanctioned by the specification. The legacy "TZ=" prefix + will continue to be supported since it is unambiguous and easy to do so. + + UPDATING: No update is required. + +- By default, cron will no longer recover panics in jobs that it runs. + Recovering can be surprising (see issue #192) and seems to be at odds with + typical behavior of libraries. Relatedly, the `cron.WithPanicLogger` option + has been removed to accommodate the more general JobWrapper type. + + UPDATING: To opt into panic recovery and configure the panic logger: +```go +cron.New(cron.WithChain( + cron.Recover(logger), // or use cron.DefaultLogger +)) +``` +- In adding support for https://github.com/go-logr/logr, `cron.WithVerboseLogger` was + removed, since it is duplicative with the leveled logging. + + UPDATING: Callers should use `WithLogger` and specify a logger that does not + discard `Info` logs. For convenience, one is provided that wraps `*log.Logger`: +```go +cron.New( + cron.WithLogger(cron.VerbosePrintfLogger(logger))) +``` + +### Background - Cron spec format + +There are two cron spec formats in common usage: + +- The "standard" cron format, described on [the Cron wikipedia page] and used by + the cron Linux system utility. + +- The cron format used by [the Quartz Scheduler], commonly used for scheduled + jobs in Java software + +[the Cron wikipedia page]: https://en.wikipedia.org/wiki/Cron +[the Quartz Scheduler]: http://www.quartz-scheduler.org/documentation/quartz-2.3.0/tutorials/tutorial-lesson-06.html + +The original version of this package included an optional "seconds" field, which +made it incompatible with both of these formats. Now, the "standard" format is +the default format accepted, and the Quartz format is opt-in. 
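+
+As a worked example of combining the options above, here is a minimal sketch
+that enables the seconds field and recovers job panics. It relies only on the
+APIs shown in this document (`WithSeconds`, `WithChain`, `Recover`,
+`DefaultLogger`, `AddFunc`, `Run`); treat it as illustrative, not canonical:
+
+```go
+package main
+
+import (
+	"fmt"
+
+	"github.com/robfig/cron/v3"
+)
+
+func main() {
+	// Seconds field enabled; any panic in a job is recovered and logged.
+	c := cron.New(
+		cron.WithSeconds(),
+		cron.WithChain(cron.Recover(cron.DefaultLogger)),
+	)
+	c.AddFunc("*/5 * * * * *", func() { fmt.Println("every five seconds") })
+	c.Run() // blocks, running jobs as they come due
+}
+```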
diff --git a/vendor/github.com/robfig/cron/v3/chain.go b/vendor/github.com/robfig/cron/v3/chain.go new file mode 100644 index 00000000..9c087b7b --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/chain.go @@ -0,0 +1,92 @@ +package cron + +import ( + "fmt" + "runtime" + "sync" + "time" +) + +// JobWrapper decorates the given Job with some behavior. +type JobWrapper func(Job) Job + +// Chain is a sequence of JobWrappers that decorates submitted jobs with +// cross-cutting behaviors like logging or synchronization. +type Chain struct { + wrappers []JobWrapper +} + +// NewChain returns a Chain consisting of the given JobWrappers. +func NewChain(c ...JobWrapper) Chain { + return Chain{c} +} + +// Then decorates the given job with all JobWrappers in the chain. +// +// This: +// NewChain(m1, m2, m3).Then(job) +// is equivalent to: +// m1(m2(m3(job))) +func (c Chain) Then(j Job) Job { + for i := range c.wrappers { + j = c.wrappers[len(c.wrappers)-i-1](j) + } + return j +} + +// Recover panics in wrapped jobs and log them with the provided logger. +func Recover(logger Logger) JobWrapper { + return func(j Job) Job { + return FuncJob(func() { + defer func() { + if r := recover(); r != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + err, ok := r.(error) + if !ok { + err = fmt.Errorf("%v", r) + } + logger.Error(err, "panic", "stack", "...\n"+string(buf)) + } + }() + j.Run() + }) + } +} + +// DelayIfStillRunning serializes jobs, delaying subsequent runs until the +// previous one is complete. Jobs running after a delay of more than a minute +// have the delay logged at Info. +func DelayIfStillRunning(logger Logger) JobWrapper { + return func(j Job) Job { + var mu sync.Mutex + return FuncJob(func() { + start := time.Now() + mu.Lock() + defer mu.Unlock() + if dur := time.Since(start); dur > time.Minute { + logger.Info("delay", "duration", dur) + } + j.Run() + }) + } +} + +// SkipIfStillRunning skips an invocation of the Job if a previous invocation is +// still running. It logs skips to the given logger at Info level. 
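+// (The implementation below uses a one-slot buffered channel as a token: an
+// invocation runs only if it can take the token, and hands it back when the
+// job finishes; if a previous invocation still holds the token, the new one
+// is dropped rather than queued.)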
+func SkipIfStillRunning(logger Logger) JobWrapper { + return func(j Job) Job { + var ch = make(chan struct{}, 1) + ch <- struct{}{} + return FuncJob(func() { + select { + case v := <-ch: + defer func() { ch <- v }() + j.Run() + default: + logger.Info("skip") + } + }) + } +} diff --git a/vendor/github.com/robfig/cron/v3/chain_test.go b/vendor/github.com/robfig/cron/v3/chain_test.go new file mode 100644 index 00000000..ec910975 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/chain_test.go @@ -0,0 +1,242 @@ +package cron + +import ( + "io/ioutil" + "log" + "reflect" + "sync" + "testing" + "time" +) + +func appendingJob(slice *[]int, value int) Job { + var m sync.Mutex + return FuncJob(func() { + m.Lock() + *slice = append(*slice, value) + m.Unlock() + }) +} + +func appendingWrapper(slice *[]int, value int) JobWrapper { + return func(j Job) Job { + return FuncJob(func() { + appendingJob(slice, value).Run() + j.Run() + }) + } +} + +func TestChain(t *testing.T) { + var nums []int + var ( + append1 = appendingWrapper(&nums, 1) + append2 = appendingWrapper(&nums, 2) + append3 = appendingWrapper(&nums, 3) + append4 = appendingJob(&nums, 4) + ) + NewChain(append1, append2, append3).Then(append4).Run() + if !reflect.DeepEqual(nums, []int{1, 2, 3, 4}) { + t.Error("unexpected order of calls:", nums) + } +} + +func TestChainRecover(t *testing.T) { + panickingJob := FuncJob(func() { + panic("panickingJob panics") + }) + + t.Run("panic exits job by default", func(t *testing.T) { + defer func() { + if err := recover(); err == nil { + t.Errorf("panic expected, but none received") + } + }() + NewChain().Then(panickingJob). + Run() + }) + + t.Run("Recovering JobWrapper recovers", func(t *testing.T) { + NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))). + Then(panickingJob). + Run() + }) + + t.Run("composed with the *IfStillRunning wrappers", func(t *testing.T) { + NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))). + Then(panickingJob). + Run() + }) +} + +type countJob struct { + m sync.Mutex + started int + done int + delay time.Duration +} + +func (j *countJob) Run() { + j.m.Lock() + j.started++ + j.m.Unlock() + time.Sleep(j.delay) + j.m.Lock() + j.done++ + j.m.Unlock() +} + +func (j *countJob) Started() int { + defer j.m.Unlock() + j.m.Lock() + return j.started +} + +func (j *countJob) Done() int { + defer j.m.Unlock() + j.m.Lock() + return j.done +} + +func TestChainDelayIfStillRunning(t *testing.T) { + + t.Run("runs immediately", func(t *testing.T) { + var j countJob + wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j) + go wrappedJob.Run() + time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete. + if c := j.Done(); c != 1 { + t.Errorf("expected job run once, immediately, got %d", c) + } + }) + + t.Run("second run immediate if first done", func(t *testing.T) { + var j countJob + wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j) + go func() { + go wrappedJob.Run() + time.Sleep(time.Millisecond) + go wrappedJob.Run() + }() + time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete. 
+		if c := j.Done(); c != 2 {
+			t.Errorf("expected job run twice, immediately, got %d", c)
+		}
+	})
+
+	t.Run("second run delayed if first not done", func(t *testing.T) {
+		var j countJob
+		j.delay = 10 * time.Millisecond
+		wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
+		go func() {
+			go wrappedJob.Run()
+			time.Sleep(time.Millisecond)
+			go wrappedJob.Run()
+		}()
+
+		// After 5ms, the first job is still in progress, and the second job was
+		// run but should be waiting for it to finish.
+		time.Sleep(5 * time.Millisecond)
+		started, done := j.Started(), j.Done()
+		if started != 1 || done != 0 {
+			t.Error("expected first job started, but not finished, got", started, done)
+		}
+
+		// Verify that the second job completes.
+		time.Sleep(25 * time.Millisecond)
+		started, done = j.Started(), j.Done()
+		if started != 2 || done != 2 {
+			t.Error("expected both jobs done, got", started, done)
+		}
+	})
+
+}
+
+func TestChainSkipIfStillRunning(t *testing.T) {
+
+	t.Run("runs immediately", func(t *testing.T) {
+		var j countJob
+		wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
+		go wrappedJob.Run()
+		time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
+		if c := j.Done(); c != 1 {
+			t.Errorf("expected job run once, immediately, got %d", c)
+		}
+	})
+
+	t.Run("second run immediate if first done", func(t *testing.T) {
+		var j countJob
+		wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
+		go func() {
+			go wrappedJob.Run()
+			time.Sleep(time.Millisecond)
+			go wrappedJob.Run()
+		}()
+		time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
+		if c := j.Done(); c != 2 {
+			t.Errorf("expected job run twice, immediately, got %d", c)
+		}
+	})
+
+	t.Run("second run skipped if first not done", func(t *testing.T) {
+		var j countJob
+		j.delay = 10 * time.Millisecond
+		wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
+		go func() {
+			go wrappedJob.Run()
+			time.Sleep(time.Millisecond)
+			go wrappedJob.Run()
+		}()
+
+		// After 5ms, the first job is still in progress, and the second job was
+		// already skipped.
+		time.Sleep(5 * time.Millisecond)
+		started, done := j.Started(), j.Done()
+		if started != 1 || done != 0 {
+			t.Error("expected first job started, but not finished, got", started, done)
+		}
+
+		// Verify that the first job completes and second does not run.
+ time.Sleep(25 * time.Millisecond) + started, done = j.Started(), j.Done() + if started != 1 || done != 1 { + t.Error("expected second job skipped, got", started, done) + } + }) + + t.Run("skip 10 jobs on rapid fire", func(t *testing.T) { + var j countJob + j.delay = 10 * time.Millisecond + wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j) + for i := 0; i < 11; i++ { + go wrappedJob.Run() + } + time.Sleep(200 * time.Millisecond) + done := j.Done() + if done != 1 { + t.Error("expected 1 jobs executed, 10 jobs dropped, got", done) + } + }) + + t.Run("different jobs independent", func(t *testing.T) { + var j1, j2 countJob + j1.delay = 10 * time.Millisecond + j2.delay = 10 * time.Millisecond + chain := NewChain(SkipIfStillRunning(DiscardLogger)) + wrappedJob1 := chain.Then(&j1) + wrappedJob2 := chain.Then(&j2) + for i := 0; i < 11; i++ { + go wrappedJob1.Run() + go wrappedJob2.Run() + } + time.Sleep(100 * time.Millisecond) + var ( + done1 = j1.Done() + done2 = j2.Done() + ) + if done1 != 1 || done2 != 1 { + t.Error("expected both jobs executed once, got", done1, "and", done2) + } + }) + +} diff --git a/vendor/github.com/robfig/cron/v3/constantdelay.go b/vendor/github.com/robfig/cron/v3/constantdelay.go new file mode 100644 index 00000000..cd6e7b1b --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/constantdelay.go @@ -0,0 +1,27 @@ +package cron + +import "time" + +// ConstantDelaySchedule represents a simple recurring duty cycle, e.g. "Every 5 minutes". +// It does not support jobs more frequent than once a second. +type ConstantDelaySchedule struct { + Delay time.Duration +} + +// Every returns a crontab Schedule that activates once every duration. +// Delays of less than a second are not supported (will round up to 1 second). +// Any fields less than a Second are truncated. +func Every(duration time.Duration) ConstantDelaySchedule { + if duration < time.Second { + duration = time.Second + } + return ConstantDelaySchedule{ + Delay: duration - time.Duration(duration.Nanoseconds())%time.Second, + } +} + +// Next returns the next time this should be run. +// This rounds so that the next activation time will be on the second. 
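+// For example, with Every(5 * time.Minute), a call at 10:00:00.300 returns
+// 10:05:00.000: the delay is added and the 300ms sub-second remainder of the
+// input time is stripped.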
+func (schedule ConstantDelaySchedule) Next(t time.Time) time.Time { + return t.Add(schedule.Delay - time.Duration(t.Nanosecond())*time.Nanosecond) +} diff --git a/vendor/github.com/robfig/cron/v3/constantdelay_test.go b/vendor/github.com/robfig/cron/v3/constantdelay_test.go new file mode 100644 index 00000000..f43a58ad --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/constantdelay_test.go @@ -0,0 +1,54 @@ +package cron + +import ( + "testing" + "time" +) + +func TestConstantDelayNext(t *testing.T) { + tests := []struct { + time string + delay time.Duration + expected string + }{ + // Simple cases + {"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"}, + {"Mon Jul 9 14:59 2012", 15 * time.Minute, "Mon Jul 9 15:14 2012"}, + {"Mon Jul 9 14:59:59 2012", 15 * time.Minute, "Mon Jul 9 15:14:59 2012"}, + + // Wrap around hours + {"Mon Jul 9 15:45 2012", 35 * time.Minute, "Mon Jul 9 16:20 2012"}, + + // Wrap around days + {"Mon Jul 9 23:46 2012", 14 * time.Minute, "Tue Jul 10 00:00 2012"}, + {"Mon Jul 9 23:45 2012", 35 * time.Minute, "Tue Jul 10 00:20 2012"}, + {"Mon Jul 9 23:35:51 2012", 44*time.Minute + 24*time.Second, "Tue Jul 10 00:20:15 2012"}, + {"Mon Jul 9 23:35:51 2012", 25*time.Hour + 44*time.Minute + 24*time.Second, "Thu Jul 11 01:20:15 2012"}, + + // Wrap around months + {"Mon Jul 9 23:35 2012", 91*24*time.Hour + 25*time.Minute, "Thu Oct 9 00:00 2012"}, + + // Wrap around minute, hour, day, month, and year + {"Mon Dec 31 23:59:45 2012", 15 * time.Second, "Tue Jan 1 00:00:00 2013"}, + + // Round to nearest second on the delay + {"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"}, + + // Round up to 1 second if the duration is less. + {"Mon Jul 9 14:45:00 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:01 2012"}, + + // Round to nearest second when calculating the next time. + {"Mon Jul 9 14:45:00.005 2012", 15 * time.Minute, "Mon Jul 9 15:00 2012"}, + + // Round to nearest second for both. + {"Mon Jul 9 14:45:00.005 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"}, + } + + for _, c := range tests { + actual := Every(c.delay).Next(getTime(c.time)) + expected := getTime(c.expected) + if actual != expected { + t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.delay, expected, actual) + } + } +} diff --git a/vendor/github.com/robfig/cron/v3/cron.go b/vendor/github.com/robfig/cron/v3/cron.go new file mode 100644 index 00000000..c7e91766 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/cron.go @@ -0,0 +1,355 @@ +package cron + +import ( + "context" + "sort" + "sync" + "time" +) + +// Cron keeps track of any number of entries, invoking the associated func as +// specified by the schedule. It may be started, stopped, and the entries may +// be inspected while running. +type Cron struct { + entries []*Entry + chain Chain + stop chan struct{} + add chan *Entry + remove chan EntryID + snapshot chan chan []Entry + running bool + logger Logger + runningMu sync.Mutex + location *time.Location + parser ScheduleParser + nextID EntryID + jobWaiter sync.WaitGroup +} + +// ScheduleParser is an interface for schedule spec parsers that return a Schedule +type ScheduleParser interface { + Parse(spec string) (Schedule, error) +} + +// Job is an interface for submitted cron jobs. +type Job interface { + Run() +} + +// Schedule describes a job's duty cycle. +type Schedule interface { + // Next returns the next activation time, later than the given time. 
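+	// Returning the zero time tells Cron the schedule will never
+	// activate again: such entries sort to the end of the queue and are
+	// never started (see ZeroSchedule in cron_test.go).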
+ // Next is invoked initially, and then each time the job is run. + Next(time.Time) time.Time +} + +// EntryID identifies an entry within a Cron instance +type EntryID int + +// Entry consists of a schedule and the func to execute on that schedule. +type Entry struct { + // ID is the cron-assigned ID of this entry, which may be used to look up a + // snapshot or remove it. + ID EntryID + + // Schedule on which this job should be run. + Schedule Schedule + + // Next time the job will run, or the zero time if Cron has not been + // started or this entry's schedule is unsatisfiable + Next time.Time + + // Prev is the last time this job was run, or the zero time if never. + Prev time.Time + + // WrappedJob is the thing to run when the Schedule is activated. + WrappedJob Job + + // Job is the thing that was submitted to cron. + // It is kept around so that user code that needs to get at the job later, + // e.g. via Entries() can do so. + Job Job +} + +// Valid returns true if this is not the zero entry. +func (e Entry) Valid() bool { return e.ID != 0 } + +// byTime is a wrapper for sorting the entry array by time +// (with zero time at the end). +type byTime []*Entry + +func (s byTime) Len() int { return len(s) } +func (s byTime) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s byTime) Less(i, j int) bool { + // Two zero times should return false. + // Otherwise, zero is "greater" than any other time. + // (To sort it at the end of the list.) + if s[i].Next.IsZero() { + return false + } + if s[j].Next.IsZero() { + return true + } + return s[i].Next.Before(s[j].Next) +} + +// New returns a new Cron job runner, modified by the given options. +// +// Available Settings +// +// Time Zone +// Description: The time zone in which schedules are interpreted +// Default: time.Local +// +// Parser +// Description: Parser converts cron spec strings into cron.Schedules. +// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron +// +// Chain +// Description: Wrap submitted jobs to customize behavior. +// Default: A chain that recovers panics and logs them to stderr. +// +// See "cron.With*" to modify the default behavior. +func New(opts ...Option) *Cron { + c := &Cron{ + entries: nil, + chain: NewChain(), + add: make(chan *Entry), + stop: make(chan struct{}), + snapshot: make(chan chan []Entry), + remove: make(chan EntryID), + running: false, + runningMu: sync.Mutex{}, + logger: DefaultLogger, + location: time.Local, + parser: standardParser, + } + for _, opt := range opts { + opt(c) + } + return c +} + +// FuncJob is a wrapper that turns a func() into a cron.Job +type FuncJob func() + +func (f FuncJob) Run() { f() } + +// AddFunc adds a func to the Cron to be run on the given schedule. +// The spec is parsed using the time zone of this Cron instance as the default. +// An opaque ID is returned that can be used to later remove it. +func (c *Cron) AddFunc(spec string, cmd func()) (EntryID, error) { + return c.AddJob(spec, FuncJob(cmd)) +} + +// AddJob adds a Job to the Cron to be run on the given schedule. +// The spec is parsed using the time zone of this Cron instance as the default. +// An opaque ID is returned that can be used to later remove it. +func (c *Cron) AddJob(spec string, cmd Job) (EntryID, error) { + schedule, err := c.parser.Parse(spec) + if err != nil { + return 0, err + } + return c.Schedule(schedule, cmd), nil +} + +// Schedule adds a Job to the Cron to be run on the given schedule. +// The job is wrapped with the configured Chain. 
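+// IDs are drawn from an incrementing counter that starts at 1, so every ID
+// returned here satisfies Entry.Valid, which reserves 0 for the zero Entry.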
+func (c *Cron) Schedule(schedule Schedule, cmd Job) EntryID { + c.runningMu.Lock() + defer c.runningMu.Unlock() + c.nextID++ + entry := &Entry{ + ID: c.nextID, + Schedule: schedule, + WrappedJob: c.chain.Then(cmd), + Job: cmd, + } + if !c.running { + c.entries = append(c.entries, entry) + } else { + c.add <- entry + } + return entry.ID +} + +// Entries returns a snapshot of the cron entries. +func (c *Cron) Entries() []Entry { + c.runningMu.Lock() + defer c.runningMu.Unlock() + if c.running { + replyChan := make(chan []Entry, 1) + c.snapshot <- replyChan + return <-replyChan + } + return c.entrySnapshot() +} + +// Location gets the time zone location +func (c *Cron) Location() *time.Location { + return c.location +} + +// Entry returns a snapshot of the given entry, or nil if it couldn't be found. +func (c *Cron) Entry(id EntryID) Entry { + for _, entry := range c.Entries() { + if id == entry.ID { + return entry + } + } + return Entry{} +} + +// Remove an entry from being run in the future. +func (c *Cron) Remove(id EntryID) { + c.runningMu.Lock() + defer c.runningMu.Unlock() + if c.running { + c.remove <- id + } else { + c.removeEntry(id) + } +} + +// Start the cron scheduler in its own goroutine, or no-op if already started. +func (c *Cron) Start() { + c.runningMu.Lock() + defer c.runningMu.Unlock() + if c.running { + return + } + c.running = true + go c.run() +} + +// Run the cron scheduler, or no-op if already running. +func (c *Cron) Run() { + c.runningMu.Lock() + if c.running { + c.runningMu.Unlock() + return + } + c.running = true + c.runningMu.Unlock() + c.run() +} + +// run the scheduler.. this is private just due to the need to synchronize +// access to the 'running' state variable. +func (c *Cron) run() { + c.logger.Info("start") + + // Figure out the next activation times for each entry. + now := c.now() + for _, entry := range c.entries { + entry.Next = entry.Schedule.Next(now) + c.logger.Info("schedule", "now", now, "entry", entry.ID, "next", entry.Next) + } + + for { + // Determine the next entry to run. + sort.Sort(byTime(c.entries)) + + var timer *time.Timer + if len(c.entries) == 0 || c.entries[0].Next.IsZero() { + // If there are no entries yet, just sleep - it still handles new entries + // and stop requests. + timer = time.NewTimer(100000 * time.Hour) + } else { + timer = time.NewTimer(c.entries[0].Next.Sub(now)) + } + + for { + select { + case now = <-timer.C: + now = now.In(c.location) + c.logger.Info("wake", "now", now) + + // Run every entry whose next time was less than now + for _, e := range c.entries { + if e.Next.After(now) || e.Next.IsZero() { + break + } + c.startJob(e.WrappedJob) + e.Prev = e.Next + e.Next = e.Schedule.Next(now) + c.logger.Info("run", "now", now, "entry", e.ID, "next", e.Next) + } + + case newEntry := <-c.add: + timer.Stop() + now = c.now() + newEntry.Next = newEntry.Schedule.Next(now) + c.entries = append(c.entries, newEntry) + c.logger.Info("added", "now", now, "entry", newEntry.ID, "next", newEntry.Next) + + case replyChan := <-c.snapshot: + replyChan <- c.entrySnapshot() + continue + + case <-c.stop: + timer.Stop() + c.logger.Info("stop") + return + + case id := <-c.remove: + timer.Stop() + now = c.now() + c.removeEntry(id) + c.logger.Info("removed", "entry", id) + } + + break + } + } +} + +// startJob runs the given job in a new goroutine. 
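+// Each started job is also registered with jobWaiter; Stop waits on that
+// group before cancelling its returned context, so callers can tell when all
+// in-flight jobs have finished.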
+func (c *Cron) startJob(j Job) { + c.jobWaiter.Add(1) + go func() { + defer c.jobWaiter.Done() + j.Run() + }() +} + +// now returns current time in c location +func (c *Cron) now() time.Time { + return time.Now().In(c.location) +} + +// Stop stops the cron scheduler if it is running; otherwise it does nothing. +// A context is returned so the caller can wait for running jobs to complete. +func (c *Cron) Stop() context.Context { + c.runningMu.Lock() + defer c.runningMu.Unlock() + if c.running { + c.stop <- struct{}{} + c.running = false + } + ctx, cancel := context.WithCancel(context.Background()) + go func() { + c.jobWaiter.Wait() + cancel() + }() + return ctx +} + +// entrySnapshot returns a copy of the current cron entry list. +func (c *Cron) entrySnapshot() []Entry { + var entries = make([]Entry, len(c.entries)) + for i, e := range c.entries { + entries[i] = *e + } + return entries +} + +func (c *Cron) removeEntry(id EntryID) { + var entries []*Entry + for _, e := range c.entries { + if e.ID != id { + entries = append(entries, e) + } + } + c.entries = entries +} diff --git a/vendor/github.com/robfig/cron/v3/cron_test.go b/vendor/github.com/robfig/cron/v3/cron_test.go new file mode 100644 index 00000000..36f06bf7 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/cron_test.go @@ -0,0 +1,702 @@ +package cron + +import ( + "bytes" + "fmt" + "log" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// Many tests schedule a job for every second, and then wait at most a second +// for it to run. This amount is just slightly larger than 1 second to +// compensate for a few milliseconds of runtime. +const OneSecond = 1*time.Second + 50*time.Millisecond + +type syncWriter struct { + wr bytes.Buffer + m sync.Mutex +} + +func (sw *syncWriter) Write(data []byte) (n int, err error) { + sw.m.Lock() + n, err = sw.wr.Write(data) + sw.m.Unlock() + return +} + +func (sw *syncWriter) String() string { + sw.m.Lock() + defer sw.m.Unlock() + return sw.wr.String() +} + +func newBufLogger(sw *syncWriter) Logger { + return PrintfLogger(log.New(sw, "", log.LstdFlags)) +} + +func TestFuncPanicRecovery(t *testing.T) { + var buf syncWriter + cron := New(WithParser(secondParser), + WithChain(Recover(newBufLogger(&buf)))) + cron.Start() + defer cron.Stop() + cron.AddFunc("* * * * * ?", func() { + panic("YOLO") + }) + + select { + case <-time.After(OneSecond): + if !strings.Contains(buf.String(), "YOLO") { + t.Error("expected a panic to be logged, got none") + } + return + } +} + +type DummyJob struct{} + +func (d DummyJob) Run() { + panic("YOLO") +} + +func TestJobPanicRecovery(t *testing.T) { + var job DummyJob + + var buf syncWriter + cron := New(WithParser(secondParser), + WithChain(Recover(newBufLogger(&buf)))) + cron.Start() + defer cron.Stop() + cron.AddJob("* * * * * ?", job) + + select { + case <-time.After(OneSecond): + if !strings.Contains(buf.String(), "YOLO") { + t.Error("expected a panic to be logged, got none") + } + return + } +} + +// Start and stop cron with no entries. +func TestNoEntries(t *testing.T) { + cron := newWithSeconds() + cron.Start() + + select { + case <-time.After(OneSecond): + t.Fatal("expected cron will be stopped immediately") + case <-stop(cron): + } +} + +// Start, stop, then add an entry. Verify entry doesn't run. 
+func TestStopCausesJobsToNotRun(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + cron := newWithSeconds() + cron.Start() + cron.Stop() + cron.AddFunc("* * * * * ?", func() { wg.Done() }) + + select { + case <-time.After(OneSecond): + // No job ran! + case <-wait(wg): + t.Fatal("expected stopped cron does not run any job") + } +} + +// Add a job, start cron, expect it runs. +func TestAddBeforeRunning(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + cron := newWithSeconds() + cron.AddFunc("* * * * * ?", func() { wg.Done() }) + cron.Start() + defer cron.Stop() + + // Give cron 2 seconds to run our job (which is always activated). + select { + case <-time.After(OneSecond): + t.Fatal("expected job runs") + case <-wait(wg): + } +} + +// Start cron, add a job, expect it runs. +func TestAddWhileRunning(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + cron := newWithSeconds() + cron.Start() + defer cron.Stop() + cron.AddFunc("* * * * * ?", func() { wg.Done() }) + + select { + case <-time.After(OneSecond): + t.Fatal("expected job runs") + case <-wait(wg): + } +} + +// Test for #34. Adding a job after calling start results in multiple job invocations +func TestAddWhileRunningWithDelay(t *testing.T) { + cron := newWithSeconds() + cron.Start() + defer cron.Stop() + time.Sleep(5 * time.Second) + var calls int64 + cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) }) + + <-time.After(OneSecond) + if atomic.LoadInt64(&calls) != 1 { + t.Errorf("called %d times, expected 1\n", calls) + } +} + +// Add a job, remove a job, start cron, expect nothing runs. +func TestRemoveBeforeRunning(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + cron := newWithSeconds() + id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() }) + cron.Remove(id) + cron.Start() + defer cron.Stop() + + select { + case <-time.After(OneSecond): + // Success, shouldn't run + case <-wait(wg): + t.FailNow() + } +} + +// Start cron, add a job, remove it, expect it doesn't run. +func TestRemoveWhileRunning(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + cron := newWithSeconds() + cron.Start() + defer cron.Stop() + id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() }) + cron.Remove(id) + + select { + case <-time.After(OneSecond): + case <-wait(wg): + t.FailNow() + } +} + +// Test timing with Entries. +func TestSnapshotEntries(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + cron := New() + cron.AddFunc("@every 2s", func() { wg.Done() }) + cron.Start() + defer cron.Stop() + + // Cron should fire in 2 seconds. After 1 second, call Entries. + select { + case <-time.After(OneSecond): + cron.Entries() + } + + // Even though Entries was called, the cron should fire at the 2 second mark. + select { + case <-time.After(OneSecond): + t.Error("expected job runs at 2 second mark") + case <-wait(wg): + } +} + +// Test that the entries are correctly sorted. +// Add a bunch of long-in-the-future entries, and an immediate entry, and ensure +// that the immediate entry runs immediately. +// Also: Test that multiple jobs run in the same instant. 
+func TestMultipleEntries(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + + cron := newWithSeconds() + cron.AddFunc("0 0 0 1 1 ?", func() {}) + cron.AddFunc("* * * * * ?", func() { wg.Done() }) + id1, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() }) + id2, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() }) + cron.AddFunc("0 0 0 31 12 ?", func() {}) + cron.AddFunc("* * * * * ?", func() { wg.Done() }) + + cron.Remove(id1) + cron.Start() + cron.Remove(id2) + defer cron.Stop() + + select { + case <-time.After(OneSecond): + t.Error("expected job run in proper order") + case <-wait(wg): + } +} + +// Test running the same job twice. +func TestRunningJobTwice(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + + cron := newWithSeconds() + cron.AddFunc("0 0 0 1 1 ?", func() {}) + cron.AddFunc("0 0 0 31 12 ?", func() {}) + cron.AddFunc("* * * * * ?", func() { wg.Done() }) + + cron.Start() + defer cron.Stop() + + select { + case <-time.After(2 * OneSecond): + t.Error("expected job fires 2 times") + case <-wait(wg): + } +} + +func TestRunningMultipleSchedules(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + + cron := newWithSeconds() + cron.AddFunc("0 0 0 1 1 ?", func() {}) + cron.AddFunc("0 0 0 31 12 ?", func() {}) + cron.AddFunc("* * * * * ?", func() { wg.Done() }) + cron.Schedule(Every(time.Minute), FuncJob(func() {})) + cron.Schedule(Every(time.Second), FuncJob(func() { wg.Done() })) + cron.Schedule(Every(time.Hour), FuncJob(func() {})) + + cron.Start() + defer cron.Stop() + + select { + case <-time.After(2 * OneSecond): + t.Error("expected job fires 2 times") + case <-wait(wg): + } +} + +// Test that the cron is run in the local time zone (as opposed to UTC). +func TestLocalTimezone(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + + now := time.Now() + // FIX: Issue #205 + // This calculation doesn't work in seconds 58 or 59. + // Take the easy way out and sleep. + if now.Second() >= 58 { + time.Sleep(2 * time.Second) + now = time.Now() + } + spec := fmt.Sprintf("%d,%d %d %d %d %d ?", + now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month()) + + cron := newWithSeconds() + cron.AddFunc(spec, func() { wg.Done() }) + cron.Start() + defer cron.Stop() + + select { + case <-time.After(OneSecond * 2): + t.Error("expected job fires 2 times") + case <-wait(wg): + } +} + +// Test that the cron is run in the given time zone (as opposed to local). +func TestNonLocalTimezone(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + + loc, err := time.LoadLocation("Atlantic/Cape_Verde") + if err != nil { + fmt.Printf("Failed to load time zone Atlantic/Cape_Verde: %+v", err) + t.Fail() + } + + now := time.Now().In(loc) + // FIX: Issue #205 + // This calculation doesn't work in seconds 58 or 59. + // Take the easy way out and sleep. + if now.Second() >= 58 { + time.Sleep(2 * time.Second) + now = time.Now().In(loc) + } + spec := fmt.Sprintf("%d,%d %d %d %d %d ?", + now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month()) + + cron := New(WithLocation(loc), WithParser(secondParser)) + cron.AddFunc(spec, func() { wg.Done() }) + cron.Start() + defer cron.Stop() + + select { + case <-time.After(OneSecond * 2): + t.Error("expected job fires 2 times") + case <-wait(wg): + } +} + +// Test that calling stop before start silently returns without +// blocking the stop channel. 
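+// (Stop only signals its stop channel when the scheduler is marked running,
+// so calling it before Start is a no-op rather than a deadlock.)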
+func TestStopWithoutStart(t *testing.T) { + cron := New() + cron.Stop() +} + +type testJob struct { + wg *sync.WaitGroup + name string +} + +func (t testJob) Run() { + t.wg.Done() +} + +// Test that adding an invalid job spec returns an error +func TestInvalidJobSpec(t *testing.T) { + cron := New() + _, err := cron.AddJob("this will not parse", nil) + if err == nil { + t.Errorf("expected an error with invalid spec, got nil") + } +} + +// Test blocking run method behaves as Start() +func TestBlockingRun(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + cron := newWithSeconds() + cron.AddFunc("* * * * * ?", func() { wg.Done() }) + + var unblockChan = make(chan struct{}) + + go func() { + cron.Run() + close(unblockChan) + }() + defer cron.Stop() + + select { + case <-time.After(OneSecond): + t.Error("expected job fires") + case <-unblockChan: + t.Error("expected that Run() blocks") + case <-wait(wg): + } +} + +// Test that double-running is a no-op +func TestStartNoop(t *testing.T) { + var tickChan = make(chan struct{}, 2) + + cron := newWithSeconds() + cron.AddFunc("* * * * * ?", func() { + tickChan <- struct{}{} + }) + + cron.Start() + defer cron.Stop() + + // Wait for the first firing to ensure the runner is going + <-tickChan + + cron.Start() + + <-tickChan + + // Fail if this job fires again in a short period, indicating a double-run + select { + case <-time.After(time.Millisecond): + case <-tickChan: + t.Error("expected job fires exactly twice") + } +} + +// Simple test using Runnables. +func TestJob(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + cron := newWithSeconds() + cron.AddJob("0 0 0 30 Feb ?", testJob{wg, "job0"}) + cron.AddJob("0 0 0 1 1 ?", testJob{wg, "job1"}) + job2, _ := cron.AddJob("* * * * * ?", testJob{wg, "job2"}) + cron.AddJob("1 0 0 1 1 ?", testJob{wg, "job3"}) + cron.Schedule(Every(5*time.Second+5*time.Nanosecond), testJob{wg, "job4"}) + job5 := cron.Schedule(Every(5*time.Minute), testJob{wg, "job5"}) + + // Test getting an Entry pre-Start. + if actualName := cron.Entry(job2).Job.(testJob).name; actualName != "job2" { + t.Error("wrong job retrieved:", actualName) + } + if actualName := cron.Entry(job5).Job.(testJob).name; actualName != "job5" { + t.Error("wrong job retrieved:", actualName) + } + + cron.Start() + defer cron.Stop() + + select { + case <-time.After(OneSecond): + t.FailNow() + case <-wait(wg): + } + + // Ensure the entries are in the right order. + expecteds := []string{"job2", "job4", "job5", "job1", "job3", "job0"} + + var actuals []string + for _, entry := range cron.Entries() { + actuals = append(actuals, entry.Job.(testJob).name) + } + + for i, expected := range expecteds { + if actuals[i] != expected { + t.Fatalf("Jobs not in the right order. (expected) %s != %s (actual)", expecteds, actuals) + } + } + + // Test getting Entries. + if actualName := cron.Entry(job2).Job.(testJob).name; actualName != "job2" { + t.Error("wrong job retrieved:", actualName) + } + if actualName := cron.Entry(job5).Job.(testJob).name; actualName != "job5" { + t.Error("wrong job retrieved:", actualName) + } +} + +// Issue #206 +// Ensure that the next run of a job after removing an entry is accurate. +func TestScheduleAfterRemoval(t *testing.T) { + var wg1 sync.WaitGroup + var wg2 sync.WaitGroup + wg1.Add(1) + wg2.Add(1) + + // The first time this job is run, set a timer and remove the other job + // 750ms later. Correct behavior would be to still run the job again in + // 250ms, but the bug would cause it to run instead 1s later. 
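+	// (Timeline, added for clarity: the every-second job first fires at some
+	// second boundary t=0s and trues up the waitgroup; on its t=1s run it
+	// sleeps 750ms and removes the hour entry at t=1.75s; the correct next
+	// run is still t=2s, while the bug would push it to t=2.75s.)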
+
+	var calls int
+	var mu sync.Mutex
+
+	cron := newWithSeconds()
+	hourJob := cron.Schedule(Every(time.Hour), FuncJob(func() {}))
+	cron.Schedule(Every(time.Second), FuncJob(func() {
+		mu.Lock()
+		defer mu.Unlock()
+		switch calls {
+		case 0:
+			wg1.Done()
+			calls++
+		case 1:
+			time.Sleep(750 * time.Millisecond)
+			cron.Remove(hourJob)
+			calls++
+		case 2:
+			calls++
+			wg2.Done()
+		case 3:
+			panic("unexpected 3rd call")
+		}
+	}))
+
+	cron.Start()
+	defer cron.Stop()
+
+	// The first run might be any length of time 0 - 1s, since the schedule
+	// rounds to the second. Wait for the first run to true up.
+	wg1.Wait()
+
+	select {
+	case <-time.After(2 * OneSecond):
+		t.Error("expected job fires 2 times")
+	case <-wait(&wg2):
+	}
+}
+
+type ZeroSchedule struct{}
+
+func (*ZeroSchedule) Next(time.Time) time.Time {
+	return time.Time{}
+}
+
+// Tests that a job whose schedule returns the zero time does not run.
+func TestJobWithZeroTimeDoesNotRun(t *testing.T) {
+	cron := newWithSeconds()
+	var calls int64
+	cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) })
+	cron.Schedule(new(ZeroSchedule), FuncJob(func() { t.Error("expected zero task will not run") }))
+	cron.Start()
+	defer cron.Stop()
+	<-time.After(OneSecond)
+	if atomic.LoadInt64(&calls) != 1 {
+		t.Errorf("called %d times, expected 1\n", calls)
+	}
+}
+
+func TestStopAndWait(t *testing.T) {
+	t.Run("nothing running, returns immediately", func(t *testing.T) {
+		cron := newWithSeconds()
+		cron.Start()
+		ctx := cron.Stop()
+		select {
+		case <-ctx.Done():
+		case <-time.After(time.Millisecond):
+			t.Error("context was not done immediately")
+		}
+	})
+
+	t.Run("repeated calls to Stop", func(t *testing.T) {
+		cron := newWithSeconds()
+		cron.Start()
+		_ = cron.Stop()
+		time.Sleep(time.Millisecond)
+		ctx := cron.Stop()
+		select {
+		case <-ctx.Done():
+		case <-time.After(time.Millisecond):
+			t.Error("context was not done immediately")
+		}
+	})
+
+	t.Run("a couple fast jobs added, still returns immediately", func(t *testing.T) {
+		cron := newWithSeconds()
+		cron.AddFunc("* * * * * *", func() {})
+		cron.Start()
+		cron.AddFunc("* * * * * *", func() {})
+		cron.AddFunc("* * * * * *", func() {})
+		cron.AddFunc("* * * * * *", func() {})
+		time.Sleep(time.Second)
+		ctx := cron.Stop()
+		select {
+		case <-ctx.Done():
+		case <-time.After(time.Millisecond):
+			t.Error("context was not done immediately")
+		}
+	})
+
+	t.Run("a couple fast jobs and a slow job added, waits for slow job", func(t *testing.T) {
+		cron := newWithSeconds()
+		cron.AddFunc("* * * * * *", func() {})
+		cron.Start()
+		cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
+		cron.AddFunc("* * * * * *", func() {})
+		time.Sleep(time.Second)
+
+		ctx := cron.Stop()
+
+		// Verify that it is not done for at least 750ms
+		select {
+		case <-ctx.Done():
+			t.Error("context was done too quickly")
+		case <-time.After(750 * time.Millisecond):
+			// expected, because the job sleeping for 2 seconds is still running
+		}
+
+		// Verify that it IS done within the next 1500ms (once the slow job completes)
+		select {
+		case <-ctx.Done():
+			// expected
+		case <-time.After(1500 * time.Millisecond):
+			t.Error("context not done after job should have completed")
+		}
+	})
+
+	t.Run("repeated calls to stop, waiting for completion and after", func(t *testing.T) {
+		cron := newWithSeconds()
+		cron.AddFunc("* * * * * *", func() {})
+		cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
+		cron.Start()
+		cron.AddFunc("* * * * * *", func() {})
+		time.Sleep(time.Second)
+		ctx := cron.Stop()
+		ctx2 := cron.Stop()
+
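+		// Note (added): each call to Stop returns its own context; both
+		// contexts should become done only once the slow job has finished,
+		// which the selects below verify.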
+		// Verify that it is not done for at least 1500ms
+		select {
+		case <-ctx.Done():
+			t.Error("context was done too quickly")
+		case <-ctx2.Done():
+			t.Error("context2 was done too quickly")
+		case <-time.After(1500 * time.Millisecond):
+			// expected, because the job sleeping for 2 seconds is still running
+		}
+
+		// Verify that it IS done in the next 1s (giving 500ms buffer)
+		select {
+		case <-ctx.Done():
+			// expected
+		case <-time.After(time.Second):
+			t.Error("context not done after job should have completed")
+		}
+
+		// Verify that ctx2 is also done.
+		select {
+		case <-ctx2.Done():
+			// expected
+		case <-time.After(time.Millisecond):
+			t.Error("context2 not done even though context1 is")
+		}
+
+		// Verify that a new context retrieved from stop is immediately done.
+		ctx3 := cron.Stop()
+		select {
+		case <-ctx3.Done():
+			// expected
+		case <-time.After(time.Millisecond):
+			t.Error("context not done even when cron Stop is completed")
+		}
+
+	})
+}
+
+func TestMultiThreadedStartAndStop(t *testing.T) {
+	cron := New()
+	go cron.Run()
+	time.Sleep(2 * time.Millisecond)
+	cron.Stop()
+}
+
+func wait(wg *sync.WaitGroup) chan bool {
+	ch := make(chan bool)
+	go func() {
+		wg.Wait()
+		ch <- true
+	}()
+	return ch
+}
+
+func stop(cron *Cron) chan bool {
+	ch := make(chan bool)
+	go func() {
+		cron.Stop()
+		ch <- true
+	}()
+	return ch
+}
+
+// newWithSeconds returns a Cron with the seconds field enabled.
+func newWithSeconds() *Cron {
+	return New(WithParser(secondParser), WithChain())
+}
diff --git a/vendor/github.com/robfig/cron/v3/doc.go b/vendor/github.com/robfig/cron/v3/doc.go
new file mode 100644
index 00000000..fa5d08b4
--- /dev/null
+++ b/vendor/github.com/robfig/cron/v3/doc.go
@@ -0,0 +1,231 @@
+/*
+Package cron implements a cron spec parser and job runner.
+
+Installation
+
+To download the specific tagged release, run:
+
+	go get github.com/robfig/cron/v3@v3.0.0
+
+Import it in your program as:
+
+	import "github.com/robfig/cron/v3"
+
+It requires Go 1.11 or later due to usage of Go Modules.
+
+Usage
+
+Callers may register Funcs to be invoked on a given schedule. Cron will run
+them in their own goroutines.
+
+	c := cron.New()
+	c.AddFunc("30 * * * *", func() { fmt.Println("Every hour on the half hour") })
+	c.AddFunc("30 3-6,20-23 * * *", func() { fmt.Println(".. in the range 3-6am, 8-11pm") })
+	c.AddFunc("CRON_TZ=Asia/Tokyo 30 04 * * *", func() { fmt.Println("Runs at 04:30 Tokyo time every day") })
+	c.AddFunc("@hourly", func() { fmt.Println("Every hour, starting an hour from now") })
+	c.AddFunc("@every 1h30m", func() { fmt.Println("Every hour thirty, starting an hour thirty from now") })
+	c.Start()
+	..
+	// Funcs are invoked in their own goroutine, asynchronously.
+	...
+	// Funcs may also be added to a running Cron
+	c.AddFunc("@daily", func() { fmt.Println("Every day") })
+	..
+	// Inspect the cron job entries' next and previous run times.
+	inspect(c.Entries())
+	..
+	c.Stop() // Stop the scheduler (does not stop any jobs already running).
+
+CRON Expression Format
+
+A cron expression represents a set of times, using 5 space-separated fields.
+
+	Field name   | Mandatory? | Allowed values  | Allowed special characters
+	----------   | ---------- | --------------  | --------------------------
+	Minutes      | Yes        | 0-59            | * / , -
+	Hours        | Yes        | 0-23            | * / , -
+	Day of month | Yes        | 1-31            | * / , - ?
+	Month        | Yes        | 1-12 or JAN-DEC | * / , -
+	Day of week  | Yes        | 0-6 or SUN-SAT  | * / , - ?
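+
+As a quick added illustration of the format (using only this package's
+ParseStandard and Schedule.Next):
+
+	sched, err := cron.ParseStandard("30 3-6,20-23 * * *")
+	if err != nil {
+		// the spec is invalid
+	}
+	fmt.Println(sched.Next(time.Now())) // next minute 30 within hours 3-6 or 20-23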
"SUN", "Sun", and +"sun" are equally accepted. + +The specific interpretation of the format is based on the Cron Wikipedia page: +https://en.wikipedia.org/wiki/Cron + +Alternative Formats + +Alternative Cron expression formats support other fields like seconds. You can +implement that by creating a custom Parser as follows. + + cron.New( + cron.WithParser( + cron.NewParser( + cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor))) + +Since adding Seconds is the most common modification to the standard cron spec, +cron provides a builtin function to do that, which is equivalent to the custom +parser you saw earlier, except that its seconds field is REQUIRED: + + cron.New(cron.WithSeconds()) + +That emulates Quartz, the most popular alternative Cron schedule format: +http://www.quartz-scheduler.org/documentation/quartz-2.x/tutorials/crontrigger.html + +Special Characters + +Asterisk ( * ) + +The asterisk indicates that the cron expression will match for all values of the +field; e.g., using an asterisk in the 5th field (month) would indicate every +month. + +Slash ( / ) + +Slashes are used to describe increments of ranges. For example 3-59/15 in the +1st field (minutes) would indicate the 3rd minute of the hour and every 15 +minutes thereafter. The form "*\/..." is equivalent to the form "first-last/...", +that is, an increment over the largest possible range of the field. The form +"N/..." is accepted as meaning "N-MAX/...", that is, starting at N, use the +increment until the end of that specific range. It does not wrap around. + +Comma ( , ) + +Commas are used to separate items of a list. For example, using "MON,WED,FRI" in +the 5th field (day of week) would mean Mondays, Wednesdays and Fridays. + +Hyphen ( - ) + +Hyphens are used to define ranges. For example, 9-17 would indicate every +hour between 9am and 5pm inclusive. + +Question mark ( ? ) + +Question mark may be used instead of '*' for leaving either day-of-month or +day-of-week blank. + +Predefined schedules + +You may use one of several pre-defined schedules in place of a cron expression. + + Entry | Description | Equivalent To + ----- | ----------- | ------------- + @yearly (or @annually) | Run once a year, midnight, Jan. 1st | 0 0 1 1 * + @monthly | Run once a month, midnight, first of month | 0 0 1 * * + @weekly | Run once a week, midnight between Sat/Sun | 0 0 * * 0 + @daily (or @midnight) | Run once a day, midnight | 0 0 * * * + @hourly | Run once an hour, beginning of hour | 0 * * * * + +Intervals + +You may also schedule a job to execute at fixed intervals, starting at the time it's added +or cron is run. This is supported by formatting the cron spec like this: + + @every + +where "duration" is a string accepted by time.ParseDuration +(http://golang.org/pkg/time/#ParseDuration). + +For example, "@every 1h30m10s" would indicate a schedule that activates after +1 hour, 30 minutes, 10 seconds, and then every interval after that. + +Note: The interval does not take the job runtime into account. For example, +if a job takes 3 minutes to run, and it is scheduled to run every 5 minutes, +it will have only 2 minutes of idle time between each run. + +Time zones + +By default, all interpretation and scheduling is done in the machine's local +time zone (time.Local). 
+
+Time zones
+
+By default, all interpretation and scheduling is done in the machine's local
+time zone (time.Local). You can specify a different time zone on construction:
+
+	cron.New(
+	    cron.WithLocation(time.UTC))
+
+Individual cron schedules may also override the time zone they are to be
+interpreted in by providing an additional space-separated field at the beginning
+of the cron spec, of the form "CRON_TZ=Asia/Tokyo".
+
+For example:
+
+	# Runs at 6am in time.Local
+	cron.New().AddFunc("0 6 * * ?", ...)
+
+	# Runs at 6am in America/New_York
+	nyc, _ := time.LoadLocation("America/New_York")
+	c := cron.New(cron.WithLocation(nyc))
+	c.AddFunc("0 6 * * ?", ...)
+
+	# Runs at 6am in Asia/Tokyo
+	cron.New().AddFunc("CRON_TZ=Asia/Tokyo 0 6 * * ?", ...)
+
+	# Runs at 6am in Asia/Tokyo: the per-schedule CRON_TZ prefix takes
+	# precedence over the instance-wide WithLocation option
+	c := cron.New(cron.WithLocation(nyc))
+	c.AddFunc("CRON_TZ=Asia/Tokyo 0 6 * * ?", ...)
+
+The prefix "TZ=(TIME ZONE)" is also supported for legacy compatibility.
+
+Be aware that jobs scheduled during daylight-savings leap-ahead transitions will
+not be run!
+
+Job Wrappers
+
+A Cron runner may be configured with a chain of job wrappers to add
+cross-cutting functionality to all submitted jobs. For example, they may be used
+to achieve the following effects:
+
+	- Recover any panics from jobs (activated by default)
+	- Delay a job's execution if the previous run hasn't completed yet
+	- Skip a job's execution if the previous run hasn't completed yet
+	- Log each job's invocations
+
+Install wrappers for all jobs added to a cron using the `cron.WithChain` option:
+
+	cron.New(cron.WithChain(
+		cron.SkipIfStillRunning(logger),
+	))
+
+Install wrappers for individual jobs by explicitly wrapping them:
+
+	job = cron.NewChain(
+		cron.SkipIfStillRunning(logger),
+	).Then(job)
+
+Thread safety
+
+Since the Cron service runs concurrently with the calling code, some amount of
+care must be taken to ensure proper synchronization.
+
+All cron methods are designed to be correctly synchronized as long as the caller
+ensures that invocations have a clear happens-before ordering between them.
+
+Logging
+
+Cron defines a Logger interface that is a subset of the one defined in
+github.com/go-logr/logr. It has two logging levels (Info and Error), and
+parameters are key/value pairs. This makes it possible for cron logging to plug
+into structured logging systems. An adapter, [Verbose]PrintfLogger, is provided
+to wrap the standard library *log.Logger.
+
+For additional insight into Cron operations, verbose logging may be activated
+which will record job runs, scheduling decisions, and added or removed jobs.
+Activate it with a one-off logger as follows:
+
+	cron.New(
+		cron.WithLogger(
+			cron.VerbosePrintfLogger(log.New(os.Stdout, "cron: ", log.LstdFlags))))
+
+
+Implementation
+
+Cron entries are stored in an array, sorted by their next activation time. Cron
+sleeps until the next job is due to be run.
+
+Upon waking:
+ - it runs each entry that is active on that second
+ - it calculates the next run times for the jobs that were run
+ - it re-sorts the array of entries by next activation time.
+ - it goes to sleep until the soonest job.
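+
+In pseudocode, the core loop resembles the following added sketch (the real
+run loop also services add, remove, snapshot, and stop requests over
+channels):
+
+	for {
+		sort entries by next activation time
+		sleep until entries[0] is due
+		for each entry that is due now:
+			go entry.Job.Run()
+			entry.Prev = entry.Next
+			entry.Next = entry.Schedule.Next(now)
+	}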
+*/ +package cron diff --git a/vendor/github.com/robfig/cron/v3/go.mod b/vendor/github.com/robfig/cron/v3/go.mod new file mode 100644 index 00000000..8c95bf47 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/go.mod @@ -0,0 +1,3 @@ +module github.com/robfig/cron/v3 + +go 1.12 diff --git a/vendor/github.com/robfig/cron/v3/logger.go b/vendor/github.com/robfig/cron/v3/logger.go new file mode 100644 index 00000000..b4efcc05 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/logger.go @@ -0,0 +1,86 @@ +package cron + +import ( + "io/ioutil" + "log" + "os" + "strings" + "time" +) + +// DefaultLogger is used by Cron if none is specified. +var DefaultLogger Logger = PrintfLogger(log.New(os.Stdout, "cron: ", log.LstdFlags)) + +// DiscardLogger can be used by callers to discard all log messages. +var DiscardLogger Logger = PrintfLogger(log.New(ioutil.Discard, "", 0)) + +// Logger is the interface used in this package for logging, so that any backend +// can be plugged in. It is a subset of the github.com/go-logr/logr interface. +type Logger interface { + // Info logs routine messages about cron's operation. + Info(msg string, keysAndValues ...interface{}) + // Error logs an error condition. + Error(err error, msg string, keysAndValues ...interface{}) +} + +// PrintfLogger wraps a Printf-based logger (such as the standard library "log") +// into an implementation of the Logger interface which logs errors only. +func PrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger { + return printfLogger{l, false} +} + +// VerbosePrintfLogger wraps a Printf-based logger (such as the standard library +// "log") into an implementation of the Logger interface which logs everything. +func VerbosePrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger { + return printfLogger{l, true} +} + +type printfLogger struct { + logger interface{ Printf(string, ...interface{}) } + logInfo bool +} + +func (pl printfLogger) Info(msg string, keysAndValues ...interface{}) { + if pl.logInfo { + keysAndValues = formatTimes(keysAndValues) + pl.logger.Printf( + formatString(len(keysAndValues)), + append([]interface{}{msg}, keysAndValues...)...) + } +} + +func (pl printfLogger) Error(err error, msg string, keysAndValues ...interface{}) { + keysAndValues = formatTimes(keysAndValues) + pl.logger.Printf( + formatString(len(keysAndValues)+2), + append([]interface{}{msg, "error", err}, keysAndValues...)...) +} + +// formatString returns a logfmt-like format string for the number of +// key/values. +func formatString(numKeysAndValues int) string { + var sb strings.Builder + sb.WriteString("%s") + if numKeysAndValues > 0 { + sb.WriteString(", ") + } + for i := 0; i < numKeysAndValues/2; i++ { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString("%v=%v") + } + return sb.String() +} + +// formatTimes formats any time.Time values as RFC3339. +func formatTimes(keysAndValues []interface{}) []interface{} { + var formattedArgs []interface{} + for _, arg := range keysAndValues { + if t, ok := arg.(time.Time); ok { + arg = t.Format(time.RFC3339) + } + formattedArgs = append(formattedArgs, arg) + } + return formattedArgs +} diff --git a/vendor/github.com/robfig/cron/v3/option.go b/vendor/github.com/robfig/cron/v3/option.go new file mode 100644 index 00000000..09e4278e --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/option.go @@ -0,0 +1,45 @@ +package cron + +import ( + "time" +) + +// Option represents a modification to the default behavior of a Cron. 
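+// Options are applied at construction time; for example (added illustration):
+//
+//	c := cron.New(cron.WithSeconds(), cron.WithLocation(time.UTC))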
+type Option func(*Cron)
+
+// WithLocation overrides the timezone of the cron instance.
+func WithLocation(loc *time.Location) Option {
+	return func(c *Cron) {
+		c.location = loc
+	}
+}
+
+// WithSeconds overrides the parser used for interpreting job schedules to
+// include a seconds field as the first one.
+func WithSeconds() Option {
+	return WithParser(NewParser(
+		Second | Minute | Hour | Dom | Month | Dow | Descriptor,
+	))
+}
+
+// WithParser overrides the parser used for interpreting job schedules.
+func WithParser(p ScheduleParser) Option {
+	return func(c *Cron) {
+		c.parser = p
+	}
+}
+
+// WithChain specifies Job wrappers to apply to all jobs added to this cron.
+// Refer to the Chain* functions in this package for provided wrappers.
+func WithChain(wrappers ...JobWrapper) Option {
+	return func(c *Cron) {
+		c.chain = NewChain(wrappers...)
+	}
+}
+
+// WithLogger uses the provided logger.
+func WithLogger(logger Logger) Option {
+	return func(c *Cron) {
+		c.logger = logger
+	}
+}
diff --git a/vendor/github.com/robfig/cron/v3/option_test.go b/vendor/github.com/robfig/cron/v3/option_test.go
new file mode 100644
index 00000000..8aef1682
--- /dev/null
+++ b/vendor/github.com/robfig/cron/v3/option_test.go
@@ -0,0 +1,42 @@
+package cron
+
+import (
+	"log"
+	"strings"
+	"testing"
+	"time"
+)
+
+func TestWithLocation(t *testing.T) {
+	c := New(WithLocation(time.UTC))
+	if c.location != time.UTC {
+		t.Errorf("expected UTC, got %v", c.location)
+	}
+}
+
+func TestWithParser(t *testing.T) {
+	var parser = NewParser(Dow)
+	c := New(WithParser(parser))
+	if c.parser != parser {
+		t.Error("expected provided parser")
+	}
+}
+
+func TestWithVerboseLogger(t *testing.T) {
+	var buf syncWriter
+	var logger = log.New(&buf, "", log.LstdFlags)
+	c := New(WithLogger(VerbosePrintfLogger(logger)))
+	if c.logger.(printfLogger).logger != logger {
+		t.Error("expected provided logger")
+	}
+
+	c.AddFunc("@every 1s", func() {})
+	c.Start()
+	time.Sleep(OneSecond)
+	c.Stop()
+	out := buf.String()
+	if !strings.Contains(out, "schedule,") ||
+		!strings.Contains(out, "run,") {
+		t.Error("expected to see some actions, got:", out)
+	}
+}
diff --git a/vendor/github.com/robfig/cron/v3/parser.go b/vendor/github.com/robfig/cron/v3/parser.go
new file mode 100644
index 00000000..8da6547a
--- /dev/null
+++ b/vendor/github.com/robfig/cron/v3/parser.go
@@ -0,0 +1,434 @@
+package cron
+
+import (
+	"fmt"
+	"math"
+	"strconv"
+	"strings"
+	"time"
+)
+
+// Configuration options for creating a parser. Most options specify which
+// fields should be included, while others enable features. If a field is not
+// included the parser will assume a default value. These options do not change
+// the order fields are parsed in.
+type ParseOption int
+
+const (
+	Second         ParseOption = 1 << iota // Seconds field, default 0
+	SecondOptional                         // Optional seconds field, default 0
+	Minute                                 // Minutes field, default 0
+	Hour                                   // Hours field, default 0
+	Dom                                    // Day of month field, default *
+	Month                                  // Month field, default *
+	Dow                                    // Day of week field, default *
+	DowOptional                            // Optional day of week field, default *
+	Descriptor                             // Allow descriptors such as @monthly, @weekly, etc.
+)
+
+var places = []ParseOption{
+	Second,
+	Minute,
+	Hour,
+	Dom,
+	Month,
+	Dow,
+}
+
+var defaults = []string{
+	"0",
+	"0",
+	"0",
+	"*",
+	"*",
+	"*",
+}
+
+// A custom Parser that can be configured.
+type Parser struct {
+	options ParseOption
+}
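+
+// As an added illustration (mirroring the doc.go examples), a Quartz-style
+// parser with a required seconds field can be built and used like this:
+//
+//	quartz := NewParser(Second | Minute | Hour | Dom | Month | Dow | Descriptor)
+//	sched, err := quartz.Parse("0 15 10 * * ?")
+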
+// NewParser creates a Parser with custom options.
+//
+// It panics if more than one Optional is given, since it would be impossible to
+// correctly infer which optional is provided or missing in general.
+//
+// Examples
+//
+//	// Standard parser without descriptors
+//	specParser := NewParser(Minute | Hour | Dom | Month | Dow)
+//	sched, err := specParser.Parse("0 0 15 */3 *")
+//
+//	// Same as above, just excludes time fields
+//	specParser := NewParser(Dom | Month | Dow)
+//	sched, err := specParser.Parse("15 */3 *")
+//
+//	// Same as above, just makes Dow optional
+//	specParser := NewParser(Dom | Month | DowOptional)
+//	sched, err := specParser.Parse("15 */3")
+//
+func NewParser(options ParseOption) Parser {
+	optionals := 0
+	if options&DowOptional > 0 {
+		optionals++
+	}
+	if options&SecondOptional > 0 {
+		optionals++
+	}
+	if optionals > 1 {
+		panic("multiple optionals may not be configured")
+	}
+	return Parser{options}
+}
+
+// Parse returns a new crontab schedule representing the given spec.
+// It returns a descriptive error if the spec is not valid.
+// It accepts crontab specs and features configured by NewParser.
+func (p Parser) Parse(spec string) (Schedule, error) {
+	if len(spec) == 0 {
+		return nil, fmt.Errorf("empty spec string")
+	}
+
+	// Extract timezone if present
+	var loc = time.Local
+	if strings.HasPrefix(spec, "TZ=") || strings.HasPrefix(spec, "CRON_TZ=") {
+		var err error
+		i := strings.Index(spec, " ")
+		eq := strings.Index(spec, "=")
+		if loc, err = time.LoadLocation(spec[eq+1 : i]); err != nil {
+			return nil, fmt.Errorf("provided bad location %s: %v", spec[eq+1:i], err)
+		}
+		spec = strings.TrimSpace(spec[i:])
+	}
+
+	// Handle named schedules (descriptors), if configured
+	if strings.HasPrefix(spec, "@") {
+		if p.options&Descriptor == 0 {
+			return nil, fmt.Errorf("parser does not accept descriptors: %v", spec)
+		}
+		return parseDescriptor(spec, loc)
+	}
+
+	// Split on whitespace.
+	fields := strings.Fields(spec)
+
+	// Validate & fill in any omitted or optional fields
+	var err error
+	fields, err = normalizeFields(fields, p.options)
+	if err != nil {
+		return nil, err
+	}
+
+	field := func(field string, r bounds) uint64 {
+		if err != nil {
+			return 0
+		}
+		var bits uint64
+		bits, err = getField(field, r)
+		return bits
+	}
+
+	var (
+		second     = field(fields[0], seconds)
+		minute     = field(fields[1], minutes)
+		hour       = field(fields[2], hours)
+		dayofmonth = field(fields[3], dom)
+		month      = field(fields[4], months)
+		dayofweek  = field(fields[5], dow)
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	return &SpecSchedule{
+		Second:   second,
+		Minute:   minute,
+		Hour:     hour,
+		Dom:      dayofmonth,
+		Month:    month,
+		Dow:      dayofweek,
+		Location: loc,
+	}, nil
+}
+
+// normalizeFields takes a subset of the time fields and returns the full set
+// with defaults (zeroes) populated for unset fields.
+//
+// As part of performing this function, it also validates that the provided
+// fields are compatible with the configured options.
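+//
+// For example (an added illustration, matching TestNormalizeFields in
+// parser_test.go): with options SecondOptional|Minute|Hour|Dom|Month|Dow,
+// the five fields {"5", "*", "*", "*", "*"} normalize to
+// {"0", "5", "*", "*", "*", "*"}: the optional seconds field is filled in
+// with its default "0".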
+func normalizeFields(fields []string, options ParseOption) ([]string, error) { + // Validate optionals & add their field to options + optionals := 0 + if options&SecondOptional > 0 { + options |= Second + optionals++ + } + if options&DowOptional > 0 { + options |= Dow + optionals++ + } + if optionals > 1 { + return nil, fmt.Errorf("multiple optionals may not be configured") + } + + // Figure out how many fields we need + max := 0 + for _, place := range places { + if options&place > 0 { + max++ + } + } + min := max - optionals + + // Validate number of fields + if count := len(fields); count < min || count > max { + if min == max { + return nil, fmt.Errorf("expected exactly %d fields, found %d: %s", min, count, fields) + } + return nil, fmt.Errorf("expected %d to %d fields, found %d: %s", min, max, count, fields) + } + + // Populate the optional field if not provided + if min < max && len(fields) == min { + switch { + case options&DowOptional > 0: + fields = append(fields, defaults[5]) // TODO: improve access to default + case options&SecondOptional > 0: + fields = append([]string{defaults[0]}, fields...) + default: + return nil, fmt.Errorf("unknown optional field") + } + } + + // Populate all fields not part of options with their defaults + n := 0 + expandedFields := make([]string, len(places)) + copy(expandedFields, defaults) + for i, place := range places { + if options&place > 0 { + expandedFields[i] = fields[n] + n++ + } + } + return expandedFields, nil +} + +var standardParser = NewParser( + Minute | Hour | Dom | Month | Dow | Descriptor, +) + +// ParseStandard returns a new crontab schedule representing the given +// standardSpec (https://en.wikipedia.org/wiki/Cron). It requires 5 entries +// representing: minute, hour, day of month, month and day of week, in that +// order. It returns a descriptive error if the spec is not valid. +// +// It accepts +// - Standard crontab specs, e.g. "* * * * ?" +// - Descriptors, e.g. "@midnight", "@every 1h30m" +func ParseStandard(standardSpec string) (Schedule, error) { + return standardParser.Parse(standardSpec) +} + +// getField returns an Int with the bits set representing all of the times that +// the field represents or error parsing field value. A "field" is a comma-separated +// list of "ranges". +func getField(field string, r bounds) (uint64, error) { + var bits uint64 + ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' }) + for _, expr := range ranges { + bit, err := getRange(expr, r) + if err != nil { + return bits, err + } + bits |= bit + } + return bits, nil +} + +// getRange returns the bits indicated by the given expression: +// number | number "-" number [ "/" number ] +// or error parsing range. +func getRange(expr string, r bounds) (uint64, error) { + var ( + start, end, step uint + rangeAndStep = strings.Split(expr, "/") + lowAndHigh = strings.Split(rangeAndStep[0], "-") + singleDigit = len(lowAndHigh) == 1 + err error + ) + + var extra uint64 + if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" 
{ + start = r.min + end = r.max + extra = starBit + } else { + start, err = parseIntOrName(lowAndHigh[0], r.names) + if err != nil { + return 0, err + } + switch len(lowAndHigh) { + case 1: + end = start + case 2: + end, err = parseIntOrName(lowAndHigh[1], r.names) + if err != nil { + return 0, err + } + default: + return 0, fmt.Errorf("too many hyphens: %s", expr) + } + } + + switch len(rangeAndStep) { + case 1: + step = 1 + case 2: + step, err = mustParseInt(rangeAndStep[1]) + if err != nil { + return 0, err + } + + // Special handling: "N/step" means "N-max/step". + if singleDigit { + end = r.max + } + if step > 1 { + extra = 0 + } + default: + return 0, fmt.Errorf("too many slashes: %s", expr) + } + + if start < r.min { + return 0, fmt.Errorf("beginning of range (%d) below minimum (%d): %s", start, r.min, expr) + } + if end > r.max { + return 0, fmt.Errorf("end of range (%d) above maximum (%d): %s", end, r.max, expr) + } + if start > end { + return 0, fmt.Errorf("beginning of range (%d) beyond end of range (%d): %s", start, end, expr) + } + if step == 0 { + return 0, fmt.Errorf("step of range should be a positive number: %s", expr) + } + + return getBits(start, end, step) | extra, nil +} + +// parseIntOrName returns the (possibly-named) integer contained in expr. +func parseIntOrName(expr string, names map[string]uint) (uint, error) { + if names != nil { + if namedInt, ok := names[strings.ToLower(expr)]; ok { + return namedInt, nil + } + } + return mustParseInt(expr) +} + +// mustParseInt parses the given expression as an int or returns an error. +func mustParseInt(expr string) (uint, error) { + num, err := strconv.Atoi(expr) + if err != nil { + return 0, fmt.Errorf("failed to parse int from %s: %s", expr, err) + } + if num < 0 { + return 0, fmt.Errorf("negative number (%d) not allowed: %s", num, expr) + } + + return uint(num), nil +} + +// getBits sets all bits in the range [min, max], modulo the given step size. +func getBits(min, max, step uint) uint64 { + var bits uint64 + + // If step is 1, use shifts. + if step == 1 { + return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min) + } + + // Else, use a simple loop. + for i := min; i <= max; i += step { + bits |= 1 << i + } + return bits +} + +// all returns all bits within the given bounds. (plus the star bit) +func all(r bounds) uint64 { + return getBits(r.min, r.max, 1) | starBit +} + +// parseDescriptor returns a predefined schedule for the expression, or error if none matches. 
+func parseDescriptor(descriptor string, loc *time.Location) (Schedule, error) { + switch descriptor { + case "@yearly", "@annually": + return &SpecSchedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Dom: 1 << dom.min, + Month: 1 << months.min, + Dow: all(dow), + Location: loc, + }, nil + + case "@monthly": + return &SpecSchedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Dom: 1 << dom.min, + Month: all(months), + Dow: all(dow), + Location: loc, + }, nil + + case "@weekly": + return &SpecSchedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Dom: all(dom), + Month: all(months), + Dow: 1 << dow.min, + Location: loc, + }, nil + + case "@daily", "@midnight": + return &SpecSchedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Dom: all(dom), + Month: all(months), + Dow: all(dow), + Location: loc, + }, nil + + case "@hourly": + return &SpecSchedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: all(hours), + Dom: all(dom), + Month: all(months), + Dow: all(dow), + Location: loc, + }, nil + + } + + const every = "@every " + if strings.HasPrefix(descriptor, every) { + duration, err := time.ParseDuration(descriptor[len(every):]) + if err != nil { + return nil, fmt.Errorf("failed to parse duration %s: %s", descriptor, err) + } + return Every(duration), nil + } + + return nil, fmt.Errorf("unrecognized descriptor: %s", descriptor) +} diff --git a/vendor/github.com/robfig/cron/v3/parser_test.go b/vendor/github.com/robfig/cron/v3/parser_test.go new file mode 100644 index 00000000..41c8c520 --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/parser_test.go @@ -0,0 +1,383 @@ +package cron + +import ( + "reflect" + "strings" + "testing" + "time" +) + +var secondParser = NewParser(Second | Minute | Hour | Dom | Month | DowOptional | Descriptor) + +func TestRange(t *testing.T) { + zero := uint64(0) + ranges := []struct { + expr string + min, max uint + expected uint64 + err string + }{ + {"5", 0, 7, 1 << 5, ""}, + {"0", 0, 7, 1 << 0, ""}, + {"7", 0, 7, 1 << 7, ""}, + + {"5-5", 0, 7, 1 << 5, ""}, + {"5-6", 0, 7, 1<<5 | 1<<6, ""}, + {"5-7", 0, 7, 1<<5 | 1<<6 | 1<<7, ""}, + + {"5-6/2", 0, 7, 1 << 5, ""}, + {"5-7/2", 0, 7, 1<<5 | 1<<7, ""}, + {"5-7/1", 0, 7, 1<<5 | 1<<6 | 1<<7, ""}, + + {"*", 1, 3, 1<<1 | 1<<2 | 1<<3 | starBit, ""}, + {"*/2", 1, 3, 1<<1 | 1<<3, ""}, + + {"5--5", 0, 0, zero, "too many hyphens"}, + {"jan-x", 0, 0, zero, "failed to parse int from"}, + {"2-x", 1, 5, zero, "failed to parse int from"}, + {"*/-12", 0, 0, zero, "negative number"}, + {"*//2", 0, 0, zero, "too many slashes"}, + {"1", 3, 5, zero, "below minimum"}, + {"6", 3, 5, zero, "above maximum"}, + {"5-3", 3, 5, zero, "beyond end of range"}, + {"*/0", 0, 0, zero, "should be a positive number"}, + } + + for _, c := range ranges { + actual, err := getRange(c.expr, bounds{c.min, c.max, nil}) + if len(c.err) != 0 && (err == nil || !strings.Contains(err.Error(), c.err)) { + t.Errorf("%s => expected %v, got %v", c.expr, c.err, err) + } + if len(c.err) == 0 && err != nil { + t.Errorf("%s => unexpected error %v", c.expr, err) + } + if actual != c.expected { + t.Errorf("%s => expected %d, got %d", c.expr, c.expected, actual) + } + } +} + +func TestField(t *testing.T) { + fields := []struct { + expr string + min, max uint + expected uint64 + }{ + {"5", 1, 7, 1 << 5}, + {"5,6", 1, 7, 1<<5 | 1<<6}, + {"5,6,7", 1, 7, 1<<5 | 1<<6 | 1<<7}, + {"1,5-7/2,3", 1, 7, 1<<1 | 1<<5 | 1<<7 
| 1<<3}, + } + + for _, c := range fields { + actual, _ := getField(c.expr, bounds{c.min, c.max, nil}) + if actual != c.expected { + t.Errorf("%s => expected %d, got %d", c.expr, c.expected, actual) + } + } +} + +func TestAll(t *testing.T) { + allBits := []struct { + r bounds + expected uint64 + }{ + {minutes, 0xfffffffffffffff}, // 0-59: 60 ones + {hours, 0xffffff}, // 0-23: 24 ones + {dom, 0xfffffffe}, // 1-31: 31 ones, 1 zero + {months, 0x1ffe}, // 1-12: 12 ones, 1 zero + {dow, 0x7f}, // 0-6: 7 ones + } + + for _, c := range allBits { + actual := all(c.r) // all() adds the starBit, so compensate for that.. + if c.expected|starBit != actual { + t.Errorf("%d-%d/%d => expected %b, got %b", + c.r.min, c.r.max, 1, c.expected|starBit, actual) + } + } +} + +func TestBits(t *testing.T) { + bits := []struct { + min, max, step uint + expected uint64 + }{ + {0, 0, 1, 0x1}, + {1, 1, 1, 0x2}, + {1, 5, 2, 0x2a}, // 101010 + {1, 4, 2, 0xa}, // 1010 + } + + for _, c := range bits { + actual := getBits(c.min, c.max, c.step) + if c.expected != actual { + t.Errorf("%d-%d/%d => expected %b, got %b", + c.min, c.max, c.step, c.expected, actual) + } + } +} + +func TestParseScheduleErrors(t *testing.T) { + var tests = []struct{ expr, err string }{ + {"* 5 j * * *", "failed to parse int from"}, + {"@every Xm", "failed to parse duration"}, + {"@unrecognized", "unrecognized descriptor"}, + {"* * * *", "expected 5 to 6 fields"}, + {"", "empty spec string"}, + } + for _, c := range tests { + actual, err := secondParser.Parse(c.expr) + if err == nil || !strings.Contains(err.Error(), c.err) { + t.Errorf("%s => expected %v, got %v", c.expr, c.err, err) + } + if actual != nil { + t.Errorf("expected nil schedule on error, got %v", actual) + } + } +} + +func TestParseSchedule(t *testing.T) { + tokyo, _ := time.LoadLocation("Asia/Tokyo") + entries := []struct { + parser Parser + expr string + expected Schedule + }{ + {secondParser, "0 5 * * * *", every5min(time.Local)}, + {standardParser, "5 * * * *", every5min(time.Local)}, + {secondParser, "CRON_TZ=UTC 0 5 * * * *", every5min(time.UTC)}, + {standardParser, "CRON_TZ=UTC 5 * * * *", every5min(time.UTC)}, + {secondParser, "CRON_TZ=Asia/Tokyo 0 5 * * * *", every5min(tokyo)}, + {secondParser, "@every 5m", ConstantDelaySchedule{5 * time.Minute}}, + {secondParser, "@midnight", midnight(time.Local)}, + {secondParser, "TZ=UTC @midnight", midnight(time.UTC)}, + {secondParser, "TZ=Asia/Tokyo @midnight", midnight(tokyo)}, + {secondParser, "@yearly", annual(time.Local)}, + {secondParser, "@annually", annual(time.Local)}, + { + parser: secondParser, + expr: "* 5 * * * *", + expected: &SpecSchedule{ + Second: all(seconds), + Minute: 1 << 5, + Hour: all(hours), + Dom: all(dom), + Month: all(months), + Dow: all(dow), + Location: time.Local, + }, + }, + } + + for _, c := range entries { + actual, err := c.parser.Parse(c.expr) + if err != nil { + t.Errorf("%s => unexpected error %v", c.expr, err) + } + if !reflect.DeepEqual(actual, c.expected) { + t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual) + } + } +} + +func TestOptionalSecondSchedule(t *testing.T) { + parser := NewParser(SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor) + entries := []struct { + expr string + expected Schedule + }{ + {"0 5 * * * *", every5min(time.Local)}, + {"5 5 * * * *", every5min5s(time.Local)}, + {"5 * * * *", every5min(time.Local)}, + } + + for _, c := range entries { + actual, err := parser.Parse(c.expr) + if err != nil { + t.Errorf("%s => unexpected error %v", c.expr, err) + 
} + if !reflect.DeepEqual(actual, c.expected) { + t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual) + } + } +} + +func TestNormalizeFields(t *testing.T) { + tests := []struct { + name string + input []string + options ParseOption + expected []string + }{ + { + "AllFields_NoOptional", + []string{"0", "5", "*", "*", "*", "*"}, + Second | Minute | Hour | Dom | Month | Dow | Descriptor, + []string{"0", "5", "*", "*", "*", "*"}, + }, + { + "AllFields_SecondOptional_Provided", + []string{"0", "5", "*", "*", "*", "*"}, + SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor, + []string{"0", "5", "*", "*", "*", "*"}, + }, + { + "AllFields_SecondOptional_NotProvided", + []string{"5", "*", "*", "*", "*"}, + SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor, + []string{"0", "5", "*", "*", "*", "*"}, + }, + { + "SubsetFields_NoOptional", + []string{"5", "15", "*"}, + Hour | Dom | Month, + []string{"0", "0", "5", "15", "*", "*"}, + }, + { + "SubsetFields_DowOptional_Provided", + []string{"5", "15", "*", "4"}, + Hour | Dom | Month | DowOptional, + []string{"0", "0", "5", "15", "*", "4"}, + }, + { + "SubsetFields_DowOptional_NotProvided", + []string{"5", "15", "*"}, + Hour | Dom | Month | DowOptional, + []string{"0", "0", "5", "15", "*", "*"}, + }, + { + "SubsetFields_SecondOptional_NotProvided", + []string{"5", "15", "*"}, + SecondOptional | Hour | Dom | Month, + []string{"0", "0", "5", "15", "*", "*"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual, err := normalizeFields(test.input, test.options) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if !reflect.DeepEqual(actual, test.expected) { + t.Errorf("expected %v, got %v", test.expected, actual) + } + }) + } +} + +func TestNormalizeFields_Errors(t *testing.T) { + tests := []struct { + name string + input []string + options ParseOption + err string + }{ + { + "TwoOptionals", + []string{"0", "5", "*", "*", "*", "*"}, + SecondOptional | Minute | Hour | Dom | Month | DowOptional, + "", + }, + { + "TooManyFields", + []string{"0", "5", "*", "*"}, + SecondOptional | Minute | Hour, + "", + }, + { + "NoFields", + []string{}, + SecondOptional | Minute | Hour, + "", + }, + { + "TooFewFields", + []string{"*"}, + SecondOptional | Minute | Hour, + "", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual, err := normalizeFields(test.input, test.options) + if err == nil { + t.Errorf("expected an error, got none. 
results: %v", actual) + } + if !strings.Contains(err.Error(), test.err) { + t.Errorf("expected error %q, got %q", test.err, err.Error()) + } + }) + } +} + +func TestStandardSpecSchedule(t *testing.T) { + entries := []struct { + expr string + expected Schedule + err string + }{ + { + expr: "5 * * * *", + expected: &SpecSchedule{1 << seconds.min, 1 << 5, all(hours), all(dom), all(months), all(dow), time.Local}, + }, + { + expr: "@every 5m", + expected: ConstantDelaySchedule{time.Duration(5) * time.Minute}, + }, + { + expr: "5 j * * *", + err: "failed to parse int from", + }, + { + expr: "* * * *", + err: "expected exactly 5 fields", + }, + } + + for _, c := range entries { + actual, err := ParseStandard(c.expr) + if len(c.err) != 0 && (err == nil || !strings.Contains(err.Error(), c.err)) { + t.Errorf("%s => expected %v, got %v", c.expr, c.err, err) + } + if len(c.err) == 0 && err != nil { + t.Errorf("%s => unexpected error %v", c.expr, err) + } + if !reflect.DeepEqual(actual, c.expected) { + t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual) + } + } +} + +func TestNoDescriptorParser(t *testing.T) { + parser := NewParser(Minute | Hour) + _, err := parser.Parse("@every 1m") + if err == nil { + t.Error("expected an error, got none") + } +} + +func every5min(loc *time.Location) *SpecSchedule { + return &SpecSchedule{1 << 0, 1 << 5, all(hours), all(dom), all(months), all(dow), loc} +} + +func every5min5s(loc *time.Location) *SpecSchedule { + return &SpecSchedule{1 << 5, 1 << 5, all(hours), all(dom), all(months), all(dow), loc} +} + +func midnight(loc *time.Location) *SpecSchedule { + return &SpecSchedule{1, 1, 1, all(dom), all(months), all(dow), loc} +} + +func annual(loc *time.Location) *SpecSchedule { + return &SpecSchedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Dom: 1 << dom.min, + Month: 1 << months.min, + Dow: all(dow), + Location: loc, + } +} diff --git a/vendor/github.com/robfig/cron/v3/spec.go b/vendor/github.com/robfig/cron/v3/spec.go new file mode 100644 index 00000000..fa1e241e --- /dev/null +++ b/vendor/github.com/robfig/cron/v3/spec.go @@ -0,0 +1,188 @@ +package cron + +import "time" + +// SpecSchedule specifies a duty cycle (to the second granularity), based on a +// traditional crontab specification. It is computed initially and stored as bit sets. +type SpecSchedule struct { + Second, Minute, Hour, Dom, Month, Dow uint64 + + // Override location for this schedule. + Location *time.Location +} + +// bounds provides a range of acceptable values (plus a map of name to value). +type bounds struct { + min, max uint + names map[string]uint +} + +// The bounds for each field. +var ( + seconds = bounds{0, 59, nil} + minutes = bounds{0, 59, nil} + hours = bounds{0, 23, nil} + dom = bounds{1, 31, nil} + months = bounds{1, 12, map[string]uint{ + "jan": 1, + "feb": 2, + "mar": 3, + "apr": 4, + "may": 5, + "jun": 6, + "jul": 7, + "aug": 8, + "sep": 9, + "oct": 10, + "nov": 11, + "dec": 12, + }} + dow = bounds{0, 6, map[string]uint{ + "sun": 0, + "mon": 1, + "tue": 2, + "wed": 3, + "thu": 4, + "fri": 5, + "sat": 6, + }} +) + +const ( + // Set the top bit if a star was included in the expression. + starBit = 1 << 63 +) + +// Next returns the next time this schedule is activated, greater than the given +// time. If no time can be found to satisfy the schedule, return the zero time. 
+func (s *SpecSchedule) Next(t time.Time) time.Time {
+	// General approach
+	//
+	// For Month, Day, Hour, Minute, Second:
+	// Check if the time value matches. If yes, continue to the next field.
+	// If the field doesn't match the schedule, then increment the field until it matches.
+	// While incrementing the field, a wrap-around brings it back to the beginning
+	// of the field list (since it is necessary to re-verify previous field
+	// values)
+
+	// Convert the given time into the schedule's timezone, if one is specified.
+	// Save the original timezone so we can convert back after we find a time.
+	// Note that schedules without a time zone specified (time.Local) are treated
+	// as local to the time provided.
+	origLocation := t.Location()
+	loc := s.Location
+	if loc == time.Local {
+		loc = t.Location()
+	}
+	if s.Location != time.Local {
+		t = t.In(s.Location)
+	}
+
+	// Start at the earliest possible time (the upcoming second).
+	t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond)
+
+	// This flag indicates whether a field has been incremented.
+	added := false
+
+	// If no time is found within five years, return zero.
+	yearLimit := t.Year() + 5
+
+WRAP:
+	if t.Year() > yearLimit {
+		return time.Time{}
+	}
+
+	// Find the first applicable month.
+	// If it's this month, then do nothing.
+	for 1<<uint(t.Month())&s.Month == 0 {
+		// If we have to add a month, reset the other parts to 0.
+		if !added {
+			added = true
+			// Otherwise, set the date at the beginning (since the current time is irrelevant).
+			t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
+		}
+		t = t.AddDate(0, 1, 0)
+
+		// Wrapped around.
+		if t.Month() == time.January {
+			goto WRAP
+		}
+	}
+
+	// Now get a day in that month.
+	//
+	// NOTE: This causes issues for daylight savings regimes where midnight does
+	// not exist. For example: Sao Paulo has DST that transforms midnight on
+	// 11/3 into 1am. Handle that by noticing when the Hour ends up != 0.
+	for !dayMatches(s, t) {
+		if !added {
+			added = true
+			t = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
+		}
+		t = t.AddDate(0, 0, 1)
+		// Notice if the hour is no longer midnight due to DST.
+		// Add an hour if it's 23, subtract an hour if it's 1.
+		if t.Hour() != 0 {
+			if t.Hour() > 12 {
+				t = t.Add(time.Duration(24-t.Hour()) * time.Hour)
+			} else {
+				t = t.Add(time.Duration(-t.Hour()) * time.Hour)
+			}
+		}
+
+		if t.Day() == 1 {
+			goto WRAP
+		}
+	}
+
+	for 1<<uint(t.Hour())&s.Hour == 0 {
+		if !added {
+			added = true
+			t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), 0, 0, 0, loc)
+		}
+		t = t.Add(1 * time.Hour)
+
+		if t.Hour() == 0 {
+			goto WRAP
+		}
+	}
+
+	for 1<<uint(t.Minute())&s.Minute == 0 {
+		if !added {
+			added = true
+			t = t.Truncate(time.Minute)
+		}
+		t = t.Add(1 * time.Minute)
+
+		if t.Minute() == 0 {
+			goto WRAP
+		}
+	}
+
+	for 1<<uint(t.Second())&s.Second == 0 {
+		if !added {
+			added = true
+			t = t.Truncate(time.Second)
+		}
+		t = t.Add(1 * time.Second)
+
+		if t.Second() == 0 {
+			goto WRAP
+		}
+	}
+
+	return t.In(origLocation)
+}
+
+// dayMatches returns true if the schedule's day-of-week and day-of-month
+// restrictions are satisfied by the given time.
+func dayMatches(s *SpecSchedule, t time.Time) bool {
+	var (
+		domMatch bool = 1<<uint(t.Day())&s.Dom > 0
+		dowMatch bool = 1<<uint(t.Weekday())&s.Dow > 0
+	)
+	if s.Dom&starBit > 0 || s.Dow&starBit > 0 {
+		return domMatch && dowMatch
+	}
+	return domMatch || dowMatch
+}
diff --git a/vendor/github.com/robfig/cron/v3/spec_test.go b/vendor/github.com/robfig/cron/v3/spec_test.go
new file mode 100644
index 00000000..1b8a503e
--- /dev/null
+++ b/vendor/github.com/robfig/cron/v3/spec_test.go
@@ -0,0 +1,300 @@
+package cron
+
+import (
+	"strings"
+	"testing"
+	"time"
+)
+
+func TestActivation(t *testing.T) {
+	tests := []struct {
+		time, spec string
+		expected   bool
+	}{
+		// Every fifteen minutes.
+		{"Mon Jul 9 15:00 2012", "0/15 * * * *", true},
+		{"Mon Jul 9 15:45 2012", "0/15 * * * *", true},
+		{"Mon Jul 9 15:40 2012", "0/15 * * * *", false},
+
+		// Every fifteen minutes, starting at 5 minutes.
+		{"Mon Jul 9 15:05 2012", "5/15 * * * *", true},
+		{"Mon Jul 9 15:20 2012", "5/15 * * * *", true},
+		{"Mon Jul 9 15:50 2012", "5/15 * * * *", true},
+
+		// Named months
+		{"Sun Jul 15 15:00 2012", "0/15 * * Jul *", true},
+		{"Sun Jul 15 15:00 2012", "0/15 * * Jun *", false},
+
+		// Everything set.
+		{"Sun Jul 15 08:30 2012", "30 08 ? Jul Sun", true},
+		{"Sun Jul 15 08:30 2012", "30 08 15 Jul ?", true},
+		{"Mon Jul 16 08:30 2012", "30 08 ? Jul Sun", false},
+		{"Mon Jul 16 08:30 2012", "30 08 15 Jul ?", false},
+
+		// Predefined schedules
+		{"Mon Jul 9 15:00 2012", "@hourly", true},
+		{"Mon Jul 9 15:04 2012", "@hourly", false},
+		{"Mon Jul 9 15:00 2012", "@daily", false},
+		{"Mon Jul 9 00:00 2012", "@daily", true},
+		{"Mon Jul 9 00:00 2012", "@weekly", false},
+		{"Sun Jul 8 00:00 2012", "@weekly", true},
+		{"Sun Jul 8 01:00 2012", "@weekly", false},
+		{"Sun Jul 8 00:00 2012", "@monthly", false},
+		{"Sun Jul 1 00:00 2012", "@monthly", true},
+
+		// Test interaction of DOW and DOM.
+		// If both are restricted, then only one needs to match.
+ {"Sun Jul 15 00:00 2012", "* * 1,15 * Sun", true}, + {"Fri Jun 15 00:00 2012", "* * 1,15 * Sun", true}, + {"Wed Aug 1 00:00 2012", "* * 1,15 * Sun", true}, + {"Sun Jul 15 00:00 2012", "* * */10 * Sun", true}, // verifies #70 + + // However, if one has a star, then both need to match. + {"Sun Jul 15 00:00 2012", "* * * * Mon", false}, + {"Mon Jul 9 00:00 2012", "* * 1,15 * *", false}, + {"Sun Jul 15 00:00 2012", "* * 1,15 * *", true}, + {"Sun Jul 15 00:00 2012", "* * */2 * Sun", true}, + } + + for _, test := range tests { + sched, err := ParseStandard(test.spec) + if err != nil { + t.Error(err) + continue + } + actual := sched.Next(getTime(test.time).Add(-1 * time.Second)) + expected := getTime(test.time) + if test.expected && expected != actual || !test.expected && expected == actual { + t.Errorf("Fail evaluating %s on %s: (expected) %s != %s (actual)", + test.spec, test.time, expected, actual) + } + } +} + +func TestNext(t *testing.T) { + runs := []struct { + time, spec string + expected string + }{ + // Simple cases + {"Mon Jul 9 14:45 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"}, + {"Mon Jul 9 14:59 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"}, + {"Mon Jul 9 14:59:59 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"}, + + // Wrap around hours + {"Mon Jul 9 15:45 2012", "0 20-35/15 * * * *", "Mon Jul 9 16:20 2012"}, + + // Wrap around days + {"Mon Jul 9 23:46 2012", "0 */15 * * * *", "Tue Jul 10 00:00 2012"}, + {"Mon Jul 9 23:45 2012", "0 20-35/15 * * * *", "Tue Jul 10 00:20 2012"}, + {"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * * * *", "Tue Jul 10 00:20:15 2012"}, + {"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 1/2 * * *", "Tue Jul 10 01:20:15 2012"}, + {"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 10-12 * * *", "Tue Jul 10 10:20:15 2012"}, + + {"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 1/2 */2 * *", "Thu Jul 11 01:20:15 2012"}, + {"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * 9-20 * *", "Wed Jul 10 00:20:15 2012"}, + {"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * 9-20 Jul *", "Wed Jul 10 00:20:15 2012"}, + + // Wrap around months + {"Mon Jul 9 23:35 2012", "0 0 0 9 Apr-Oct ?", "Thu Aug 9 00:00 2012"}, + {"Mon Jul 9 23:35 2012", "0 0 0 */5 Apr,Aug,Oct Mon", "Tue Aug 1 00:00 2012"}, + {"Mon Jul 9 23:35 2012", "0 0 0 */5 Oct Mon", "Mon Oct 1 00:00 2012"}, + + // Wrap around years + {"Mon Jul 9 23:35 2012", "0 0 0 * Feb Mon", "Mon Feb 4 00:00 2013"}, + {"Mon Jul 9 23:35 2012", "0 0 0 * Feb Mon/2", "Fri Feb 1 00:00 2013"}, + + // Wrap around minute, hour, day, month, and year + {"Mon Dec 31 23:59:45 2012", "0 * * * * *", "Tue Jan 1 00:00:00 2013"}, + + // Leap year + {"Mon Jul 9 23:35 2012", "0 0 0 29 Feb ?", "Mon Feb 29 00:00 2016"}, + + // Daylight savings time 2am EST (-5) -> 3am EDT (-4) + {"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 30 2 11 Mar ?", "2013-03-11T02:30:00-0400"}, + + // hourly job + {"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T01:00:00-0500"}, + {"2012-03-11T01:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T03:00:00-0400"}, + {"2012-03-11T03:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T04:00:00-0400"}, + {"2012-03-11T04:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T05:00:00-0400"}, + + // hourly job using CRON_TZ + {"2012-03-11T00:00:00-0500", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T01:00:00-0500"}, + {"2012-03-11T01:00:00-0500", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T03:00:00-0400"}, + {"2012-03-11T03:00:00-0400", "CRON_TZ=America/New_York 0 0 * * * ?", 
"2012-03-11T04:00:00-0400"}, + {"2012-03-11T04:00:00-0400", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T05:00:00-0400"}, + + // 1am nightly job + {"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-03-11T01:00:00-0500"}, + {"2012-03-11T01:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-03-12T01:00:00-0400"}, + + // 2am nightly job (skipped) + {"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 2 * * ?", "2012-03-12T02:00:00-0400"}, + + // Daylight savings time 2am EDT (-4) => 1am EST (-5) + {"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 30 2 04 Nov ?", "2012-11-04T02:30:00-0500"}, + {"2012-11-04T01:45:00-0400", "TZ=America/New_York 0 30 1 04 Nov ?", "2012-11-04T01:30:00-0500"}, + + // hourly job + {"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T01:00:00-0400"}, + {"2012-11-04T01:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T01:00:00-0500"}, + {"2012-11-04T01:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T02:00:00-0500"}, + + // 1am nightly job (runs twice) + {"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 1 * * ?", "2012-11-04T01:00:00-0400"}, + {"2012-11-04T01:00:00-0400", "TZ=America/New_York 0 0 1 * * ?", "2012-11-04T01:00:00-0500"}, + {"2012-11-04T01:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-11-05T01:00:00-0500"}, + + // 2am nightly job + {"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 2 * * ?", "2012-11-04T02:00:00-0500"}, + {"2012-11-04T02:00:00-0500", "TZ=America/New_York 0 0 2 * * ?", "2012-11-05T02:00:00-0500"}, + + // 3am nightly job + {"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 3 * * ?", "2012-11-04T03:00:00-0500"}, + {"2012-11-04T03:00:00-0500", "TZ=America/New_York 0 0 3 * * ?", "2012-11-05T03:00:00-0500"}, + + // hourly job + {"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 * * * ?", "2012-11-04T01:00:00-0400"}, + {"TZ=America/New_York 2012-11-04T01:00:00-0400", "0 0 * * * ?", "2012-11-04T01:00:00-0500"}, + {"TZ=America/New_York 2012-11-04T01:00:00-0500", "0 0 * * * ?", "2012-11-04T02:00:00-0500"}, + + // 1am nightly job (runs twice) + {"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 1 * * ?", "2012-11-04T01:00:00-0400"}, + {"TZ=America/New_York 2012-11-04T01:00:00-0400", "0 0 1 * * ?", "2012-11-04T01:00:00-0500"}, + {"TZ=America/New_York 2012-11-04T01:00:00-0500", "0 0 1 * * ?", "2012-11-05T01:00:00-0500"}, + + // 2am nightly job + {"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 2 * * ?", "2012-11-04T02:00:00-0500"}, + {"TZ=America/New_York 2012-11-04T02:00:00-0500", "0 0 2 * * ?", "2012-11-05T02:00:00-0500"}, + + // 3am nightly job + {"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 3 * * ?", "2012-11-04T03:00:00-0500"}, + {"TZ=America/New_York 2012-11-04T03:00:00-0500", "0 0 3 * * ?", "2012-11-05T03:00:00-0500"}, + + // Unsatisfiable + {"Mon Jul 9 23:35 2012", "0 0 0 30 Feb ?", ""}, + {"Mon Jul 9 23:35 2012", "0 0 0 31 Apr ?", ""}, + + // Monthly job + {"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 3 3 * ?", "2012-12-03T03:00:00-0500"}, + + // Test the scenario of DST resulting in midnight not being a valid time. 
+ // https://github.com/robfig/cron/issues/157 + {"2018-10-17T05:00:00-0400", "TZ=America/Sao_Paulo 0 0 9 10 * ?", "2018-11-10T06:00:00-0500"}, + {"2018-02-14T05:00:00-0500", "TZ=America/Sao_Paulo 0 0 9 22 * ?", "2018-02-22T07:00:00-0500"}, + } + + for _, c := range runs { + sched, err := secondParser.Parse(c.spec) + if err != nil { + t.Error(err) + continue + } + actual := sched.Next(getTime(c.time)) + expected := getTime(c.expected) + if !actual.Equal(expected) { + t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.spec, expected, actual) + } + } +} + +func TestErrors(t *testing.T) { + invalidSpecs := []string{ + "xyz", + "60 0 * * *", + "0 60 * * *", + "0 0 * * XYZ", + } + for _, spec := range invalidSpecs { + _, err := ParseStandard(spec) + if err == nil { + t.Error("expected an error parsing: ", spec) + } + } +} + +func getTime(value string) time.Time { + if value == "" { + return time.Time{} + } + + var location = time.Local + if strings.HasPrefix(value, "TZ=") { + parts := strings.Fields(value) + loc, err := time.LoadLocation(parts[0][len("TZ="):]) + if err != nil { + panic("could not parse location:" + err.Error()) + } + location = loc + value = parts[1] + } + + var layouts = []string{ + "Mon Jan 2 15:04 2006", + "Mon Jan 2 15:04:05 2006", + } + for _, layout := range layouts { + if t, err := time.ParseInLocation(layout, value, location); err == nil { + return t + } + } + if t, err := time.ParseInLocation("2006-01-02T15:04:05-0700", value, location); err == nil { + return t + } + panic("could not parse time value " + value) +} + +func TestNextWithTz(t *testing.T) { + runs := []struct { + time, spec string + expected string + }{ + // Failing tests + {"2016-01-03T13:09:03+0530", "14 14 * * *", "2016-01-03T14:14:00+0530"}, + {"2016-01-03T04:09:03+0530", "14 14 * * ?", "2016-01-03T14:14:00+0530"}, + + // Passing tests + {"2016-01-03T14:09:03+0530", "14 14 * * *", "2016-01-03T14:14:00+0530"}, + {"2016-01-03T14:00:00+0530", "14 14 * * ?", "2016-01-03T14:14:00+0530"}, + } + for _, c := range runs { + sched, err := ParseStandard(c.spec) + if err != nil { + t.Error(err) + continue + } + actual := sched.Next(getTimeTZ(c.time)) + expected := getTimeTZ(c.expected) + if !actual.Equal(expected) { + t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.spec, expected, actual) + } + } +} + +func getTimeTZ(value string) time.Time { + if value == "" { + return time.Time{} + } + t, err := time.Parse("Mon Jan 2 15:04 2006", value) + if err != nil { + t, err = time.Parse("Mon Jan 2 15:04:05 2006", value) + if err != nil { + t, err = time.Parse("2006-01-02T15:04:05-0700", value) + if err != nil { + panic(err) + } + } + } + + return t +} + +// https://github.com/robfig/cron/issues/144 +func TestSlash0NoHang(t *testing.T) { + schedule := "TZ=America/New_York 15/0 * * * *" + _, err := ParseStandard(schedule) + if err == nil { + t.Error("expected an error on 0 increment") + } +} diff --git a/vendor/github.com/templexxx/cpu/.gitignore b/vendor/github.com/templexxx/cpu/.gitignore new file mode 100644 index 00000000..63c61f36 --- /dev/null +++ b/vendor/github.com/templexxx/cpu/.gitignore @@ -0,0 +1,13 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out +.idea/ diff --git a/vendor/github.com/templexxx/cpu/LICENSE b/vendor/github.com/templexxx/cpu/LICENSE new file mode 100644 index 00000000..dfa8f7b8 --- /dev/null +++ 
b/vendor/github.com/templexxx/cpu/LICENSE @@ -0,0 +1,32 @@
+BSD 3-Clause License
+
+Copyright (c) 2018 Temple3x (temple3x@gmail.com)
+Copyright 2017 The Go Authors
+Copyright (c) 2015 Klaus Post
+
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/vendor/github.com/templexxx/cpu/README.md b/vendor/github.com/templexxx/cpu/README.md new file mode 100644 index 00000000..50ccb9fd --- /dev/null +++ b/vendor/github.com/templexxx/cpu/README.md @@ -0,0 +1,23 @@
+# cpu
+internal/cpu (from the Go standard library) with these additional detections:
+
+>- AVX512
+>
+>- Cache Size
+>
+>- Invariant TSC
+>
+
+It also provides:
+
+>- False sharing range (see `X86FalseSharingRange` for the X86 platform)
+>
+>- TSC frequency
+>
+>- Name
+>
+>- Family & Model
+
+# Acknowledgement
+
+[klauspost/cpuid](https://github.com/klauspost/cpuid) \ No newline at end of file
diff --git a/vendor/github.com/templexxx/cpu/cpu.go b/vendor/github.com/templexxx/cpu/cpu.go new file mode 100644 index 00000000..767f8b3f --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu.go @@ -0,0 +1,235 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package cpu implements processor feature detection
+// used by the Go standard library.
+package cpu
+
+// debugOptions is set to true by the runtime if go was compiled with GOEXPERIMENT=debugcpu
+// and GOOS is Linux or Darwin. This variable is linknamed in runtime/proc.go.
+var debugOptions bool
+
+var X86 x86
+
+// "Loads data or instructions from memory to the second-level cache.
+// To use the streamer, organize the data or instructions in blocks of 128 bytes,
+// aligned on 128 bytes."
+// From the Intel 64 and IA-32 Architectures Optimization Reference Manual,
+// in section 3.7.3 "Hardware Prefetching for Second-Level Cache"
+//
+// In practice, I have found that using 128 bytes gives better performance than 64 bytes (one cache line).
const X86FalseSharingRange = 128 + +// The booleans in x86 contain the correspondingly named cpuid feature bit.
+// HasAVX and HasAVX2 are only set if the OS supports XMM and YMM registers
+// in addition to the cpuid feature bit being set.
+// The struct is padded to avoid false sharing.
+type x86 struct {
+	_            [X86FalseSharingRange]byte
+	HasCMPXCHG16B bool
+	HasAES       bool
+	HasADX       bool
+	HasAVX       bool
+	HasAVX2      bool
+	HasAVX512F   bool
+	HasAVX512DQ  bool
+	HasAVX512BW  bool
+	HasAVX512VL  bool
+	HasBMI1      bool
+	HasBMI2      bool
+	HasERMS      bool
+	HasFMA       bool
+	HasOSXSAVE   bool
+	HasPCLMULQDQ bool
+	HasPOPCNT    bool
+	HasSSE2      bool
+	HasSSE3      bool
+	HasSSSE3     bool
+	HasSSE41     bool
+	HasSSE42     bool
+	// The invariant TSC will run at a constant rate in all ACPI P-, C-, and T-states.
+	// This is the architectural behavior moving forward. On processors with
+	// invariant TSC support, the OS may use the TSC for wall clock timer services (instead of ACPI or HPET timers).
+	HasInvariantTSC bool
+
+	Cache Cache
+
+	// TSCFrequency is only meaningful when HasInvariantTSC == true.
+	// Unit: Hz.
+	//
+	// Warning:
+	// 1. If it's 0, the frequency could not be found in the table provided by the Intel manual.
+	TSCFrequency uint64
+
+	Name       string
+	Signature  string // DisplayFamily_DisplayModel.
+	Family     uint32 // CPU family number.
+	Model      uint32 // CPU model number.
+	SteppingID uint32
+
+	_ [X86FalseSharingRange]byte
+}
+
+// Cache is the CPU cache size;
+// a field is -1 if undetected.
+type Cache struct {
+	L1I int
+	L1D int
+	L2  int
+	L3  int
+}
+
+var PPC64 ppc64
+
+// For ppc64x, it is safe to check only for ISA level starting on ISA v3.00,
+// since there are no optional categories. There are some exceptions that also
+// require kernel support to work (darn, scv), so there are feature bits for
+// those as well. The minimum processor requirement is POWER8 (ISA 2.07), so we
+// maintain some of the old feature checks for optional categories for
+// safety.
+// The struct is padded to avoid false sharing.
+type ppc64 struct {
+	_         [CacheLineSize]byte
+	HasVMX    bool // Vector unit (Altivec)
+	HasDFP    bool // Decimal Floating Point unit
+	HasVSX    bool // Vector-scalar unit
+	HasHTM    bool // Hardware Transactional Memory
+	HasISEL   bool // Integer select
+	HasVCRYPTO bool // Vector cryptography
+	HasHTMNOSC bool // HTM: kernel-aborted transaction in syscalls
+	HasDARN   bool // Hardware random number generator (requires kernel enablement)
+	HasSCV    bool // Syscall vectored (requires kernel enablement)
+	IsPOWER8  bool // ISA v2.07 (POWER8)
+	IsPOWER9  bool // ISA v3.00 (POWER9)
+	_         [CacheLineSize]byte
+}
+
+var ARM64 arm64
+
+// The booleans in arm64 contain the correspondingly named cpu feature bit.
+// The struct is padded to avoid false sharing.
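+// (The leading and trailing [CacheLineSize]byte pads keep these flags on
+// cache lines of their own, so writes to neighboring data cannot invalidate
+// a reader's cached copy of the flags.)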
+type arm64 struct {
+	_           [CacheLineSize]byte
+	HasFP       bool
+	HasASIMD    bool
+	HasEVTSTRM  bool
+	HasAES      bool
+	HasPMULL    bool
+	HasSHA1     bool
+	HasSHA2     bool
+	HasCRC32    bool
+	HasATOMICS  bool
+	HasFPHP     bool
+	HasASIMDHP  bool
+	HasCPUID    bool
+	HasASIMDRDM bool
+	HasJSCVT    bool
+	HasFCMA     bool
+	HasLRCPC    bool
+	HasDCPOP    bool
+	HasSHA3     bool
+	HasSM3      bool
+	HasSM4      bool
+	HasASIMDDP  bool
+	HasSHA512   bool
+	HasSVE      bool
+	HasASIMDFHM bool
+	_           [CacheLineSize]byte
+}
+
+var S390X s390x
+
+type s390x struct {
+	_               [CacheLineSize]byte
+	HasZArch        bool // z architecture mode is active [mandatory]
+	HasSTFLE        bool // store facility list extended [mandatory]
+	HasLDisp        bool // long (20-bit) displacements [mandatory]
+	HasEImm         bool // 32-bit immediates [mandatory]
+	HasDFP          bool // decimal floating point
+	HasETF3Enhanced bool // ETF-3 enhanced
+	HasMSA          bool // message security assist (CPACF)
+	HasAES          bool // KM-AES{128,192,256} functions
+	HasAESCBC       bool // KMC-AES{128,192,256} functions
+	HasAESCTR       bool // KMCTR-AES{128,192,256} functions
+	HasAESGCM       bool // KMA-GCM-AES{128,192,256} functions
+	HasGHASH        bool // KIMD-GHASH function
+	HasSHA1         bool // K{I,L}MD-SHA-1 functions
+	HasSHA256       bool // K{I,L}MD-SHA-256 functions
+	HasSHA512       bool // K{I,L}MD-SHA-512 functions
+	HasVX           bool // vector facility. Note: the runtime sets this when it processes auxv records.
+	_               [CacheLineSize]byte
+}
+
+// init examines the processor and sets the relevant variables above.
+// In the standard library this is called by the runtime package early in
+// program initialization, before normal init functions are run. env would be
+// set by the runtime on Linux and Darwin if go was compiled with GOEXPERIMENT=debugcpu.
+func init() {
+	doinit()
+	processOptions("")
+}
+
+// options contains the cpu debug options that can be used in GODEBUGCPU.
+// Options are arch dependent and are added by the arch specific doinit functions.
+// Features that are mandatory for the specific GOARCH should not be added to options
+// (e.g. SSE2 on amd64).
+var options []option
+
+// Option names should be lower case, e.g. avx instead of AVX.
+type option struct {
+	Name    string
+	Feature *bool
+}
+
+// processOptions disables CPU feature values based on the parsed env string.
+// The env string is expected to be of the form feature1=0,feature2=0...
+// where each feature name is one of the architecture-specific names stored in
+// the cpu package's options variable. If env contains all=0 then all capabilities
+// referenced through the options variable are disabled. Other feature
+// names and values other than 0 are silently ignored.
+func processOptions(env string) {
+field:
+	for env != "" {
+		field := ""
+		i := indexByte(env, ',')
+		if i < 0 {
+			field, env = env, ""
+		} else {
+			field, env = env[:i], env[i+1:]
+		}
+		i = indexByte(field, '=')
+		if i < 0 {
+			continue
+		}
+		key, value := field[:i], field[i+1:]
+
+		// Only allow turning off CPU features by specifying '0'.
+		if value == "0" {
+			if key == "all" {
+				for _, v := range options {
+					*v.Feature = false
+				}
+				return
+			} else {
+				for _, v := range options {
+					if v.Name == key {
+						*v.Feature = false
+						continue field
+					}
+				}
+			}
+		}
+	}
+}
+
+// indexByte returns the index of the first instance of c in s,
+// or -1 if c is not present in s.
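+// For example, with an option string of "avx2=0,erms=0" processOptions clears
+// X86.HasAVX2 and X86.HasERMS, and "all=0" clears every feature registered in
+// options; indexByte below is the splitting helper it relies on.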
+func indexByte(s string, c byte) int { + for i := 0; i < len(s); i++ { + if s[i] == c { + return i + } + } + return -1 +} diff --git a/vendor/github.com/templexxx/cpu/cpu_386.go b/vendor/github.com/templexxx/cpu/cpu_386.go new file mode 100644 index 00000000..561c81f8 --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_386.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const GOARCH = "386" diff --git a/vendor/github.com/templexxx/cpu/cpu_amd64.go b/vendor/github.com/templexxx/cpu/cpu_amd64.go new file mode 100644 index 00000000..9b001536 --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_amd64.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const GOARCH = "amd64" diff --git a/vendor/github.com/templexxx/cpu/cpu_amd64p32.go b/vendor/github.com/templexxx/cpu/cpu_amd64p32.go new file mode 100644 index 00000000..177b14e3 --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_amd64p32.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const GOARCH = "amd64p32" diff --git a/vendor/github.com/templexxx/cpu/cpu_arm.go b/vendor/github.com/templexxx/cpu/cpu_arm.go new file mode 100644 index 00000000..078a6c3b --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_arm.go @@ -0,0 +1,7 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLineSize = 32 diff --git a/vendor/github.com/templexxx/cpu/cpu_arm64.go b/vendor/github.com/templexxx/cpu/cpu_arm64.go new file mode 100644 index 00000000..487ccf8e --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_arm64.go @@ -0,0 +1,102 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLineSize = 64 + +// arm64 doesn't have a 'cpuid' equivalent, so we rely on HWCAP/HWCAP2. +// These are linknamed in runtime/os_linux_arm64.go and are initialized by +// archauxv(). +var hwcap uint +var hwcap2 uint + +// HWCAP/HWCAP2 bits. These are exposed by Linux. 
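+// A minimal usage sketch (not part of the upstream source; assumes this
+// vendored package is imported as cpu "github.com/templexxx/cpu"): callers
+// typically gate an accelerated path on these flags, e.g.
+//
+//	if cpu.ARM64.HasAES && cpu.ARM64.HasPMULL {
+//		// take the hardware AES/GHASH fast path
+//	}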
+const ( + hwcap_FP = (1 << 0) + hwcap_ASIMD = (1 << 1) + hwcap_EVTSTRM = (1 << 2) + hwcap_AES = (1 << 3) + hwcap_PMULL = (1 << 4) + hwcap_SHA1 = (1 << 5) + hwcap_SHA2 = (1 << 6) + hwcap_CRC32 = (1 << 7) + hwcap_ATOMICS = (1 << 8) + hwcap_FPHP = (1 << 9) + hwcap_ASIMDHP = (1 << 10) + hwcap_CPUID = (1 << 11) + hwcap_ASIMDRDM = (1 << 12) + hwcap_JSCVT = (1 << 13) + hwcap_FCMA = (1 << 14) + hwcap_LRCPC = (1 << 15) + hwcap_DCPOP = (1 << 16) + hwcap_SHA3 = (1 << 17) + hwcap_SM3 = (1 << 18) + hwcap_SM4 = (1 << 19) + hwcap_ASIMDDP = (1 << 20) + hwcap_SHA512 = (1 << 21) + hwcap_SVE = (1 << 22) + hwcap_ASIMDFHM = (1 << 23) +) + +func doinit() { + options = []option{ + {"evtstrm", &ARM64.HasEVTSTRM}, + {"aes", &ARM64.HasAES}, + {"pmull", &ARM64.HasPMULL}, + {"sha1", &ARM64.HasSHA1}, + {"sha2", &ARM64.HasSHA2}, + {"crc32", &ARM64.HasCRC32}, + {"atomics", &ARM64.HasATOMICS}, + {"fphp", &ARM64.HasFPHP}, + {"asimdhp", &ARM64.HasASIMDHP}, + {"cpuid", &ARM64.HasCPUID}, + {"asimdrdm", &ARM64.HasASIMDRDM}, + {"jscvt", &ARM64.HasJSCVT}, + {"fcma", &ARM64.HasFCMA}, + {"lrcpc", &ARM64.HasLRCPC}, + {"dcpop", &ARM64.HasDCPOP}, + {"sha3", &ARM64.HasSHA3}, + {"sm3", &ARM64.HasSM3}, + {"sm4", &ARM64.HasSM4}, + {"asimddp", &ARM64.HasASIMDDP}, + {"sha512", &ARM64.HasSHA512}, + {"sve", &ARM64.HasSVE}, + {"asimdfhm", &ARM64.HasASIMDFHM}, + + // These capabilities should always be enabled on arm64: + // {"fp", &ARM64.HasFP}, + // {"asimd", &ARM64.HasASIMD}, + } + + // HWCAP feature bits + ARM64.HasFP = isSet(hwcap, hwcap_FP) + ARM64.HasASIMD = isSet(hwcap, hwcap_ASIMD) + ARM64.HasEVTSTRM = isSet(hwcap, hwcap_EVTSTRM) + ARM64.HasAES = isSet(hwcap, hwcap_AES) + ARM64.HasPMULL = isSet(hwcap, hwcap_PMULL) + ARM64.HasSHA1 = isSet(hwcap, hwcap_SHA1) + ARM64.HasSHA2 = isSet(hwcap, hwcap_SHA2) + ARM64.HasCRC32 = isSet(hwcap, hwcap_CRC32) + ARM64.HasATOMICS = isSet(hwcap, hwcap_ATOMICS) + ARM64.HasFPHP = isSet(hwcap, hwcap_FPHP) + ARM64.HasASIMDHP = isSet(hwcap, hwcap_ASIMDHP) + ARM64.HasCPUID = isSet(hwcap, hwcap_CPUID) + ARM64.HasASIMDRDM = isSet(hwcap, hwcap_ASIMDRDM) + ARM64.HasJSCVT = isSet(hwcap, hwcap_JSCVT) + ARM64.HasFCMA = isSet(hwcap, hwcap_FCMA) + ARM64.HasLRCPC = isSet(hwcap, hwcap_LRCPC) + ARM64.HasDCPOP = isSet(hwcap, hwcap_DCPOP) + ARM64.HasSHA3 = isSet(hwcap, hwcap_SHA3) + ARM64.HasSM3 = isSet(hwcap, hwcap_SM3) + ARM64.HasSM4 = isSet(hwcap, hwcap_SM4) + ARM64.HasASIMDDP = isSet(hwcap, hwcap_ASIMDDP) + ARM64.HasSHA512 = isSet(hwcap, hwcap_SHA512) + ARM64.HasSVE = isSet(hwcap, hwcap_SVE) + ARM64.HasASIMDFHM = isSet(hwcap, hwcap_ASIMDFHM) +} + +func isSet(hwc uint, value uint) bool { + return hwc&value != 0 +} diff --git a/vendor/github.com/templexxx/cpu/cpu_mips.go b/vendor/github.com/templexxx/cpu/cpu_mips.go new file mode 100644 index 00000000..078a6c3b --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_mips.go @@ -0,0 +1,7 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLineSize = 32 diff --git a/vendor/github.com/templexxx/cpu/cpu_mips64.go b/vendor/github.com/templexxx/cpu/cpu_mips64.go new file mode 100644 index 00000000..078a6c3b --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_mips64.go @@ -0,0 +1,7 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package cpu + +const CacheLineSize = 32 diff --git a/vendor/github.com/templexxx/cpu/cpu_mips64le.go b/vendor/github.com/templexxx/cpu/cpu_mips64le.go new file mode 100644 index 00000000..078a6c3b --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_mips64le.go @@ -0,0 +1,7 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLineSize = 32 diff --git a/vendor/github.com/templexxx/cpu/cpu_mipsle.go b/vendor/github.com/templexxx/cpu/cpu_mipsle.go new file mode 100644 index 00000000..078a6c3b --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_mipsle.go @@ -0,0 +1,7 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLineSize = 32 diff --git a/vendor/github.com/templexxx/cpu/cpu_no_init.go b/vendor/github.com/templexxx/cpu/cpu_no_init.go new file mode 100644 index 00000000..1be4f29d --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_no_init.go @@ -0,0 +1,16 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !386 +// +build !amd64 +// +build !amd64p32 +// +build !arm64 +// +build !ppc64 +// +build !ppc64le +// +build !s390x + +package cpu + +func doinit() { +} diff --git a/vendor/github.com/templexxx/cpu/cpu_ppc64x.go b/vendor/github.com/templexxx/cpu/cpu_ppc64x.go new file mode 100644 index 00000000..995cf020 --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_ppc64x.go @@ -0,0 +1,68 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ppc64 ppc64le + +package cpu + +const CacheLineSize = 128 + +// ppc64x doesn't have a 'cpuid' equivalent, so we rely on HWCAP/HWCAP2. +// These are linknamed in runtime/os_linux_ppc64x.go and are initialized by +// archauxv(). +var hwcap uint +var hwcap2 uint + +// HWCAP/HWCAP2 bits. These are exposed by the kernel. 
+const ( + // ISA Level + _PPC_FEATURE2_ARCH_2_07 = 0x80000000 + _PPC_FEATURE2_ARCH_3_00 = 0x00800000 + + // CPU features + _PPC_FEATURE_HAS_ALTIVEC = 0x10000000 + _PPC_FEATURE_HAS_DFP = 0x00000400 + _PPC_FEATURE_HAS_VSX = 0x00000080 + _PPC_FEATURE2_HAS_HTM = 0x40000000 + _PPC_FEATURE2_HAS_ISEL = 0x08000000 + _PPC_FEATURE2_HAS_VEC_CRYPTO = 0x02000000 + _PPC_FEATURE2_HTM_NOSC = 0x01000000 + _PPC_FEATURE2_DARN = 0x00200000 + _PPC_FEATURE2_SCV = 0x00100000 +) + +func doinit() { + options = []option{ + {"htm", &PPC64.HasHTM}, + {"htmnosc", &PPC64.HasHTMNOSC}, + {"darn", &PPC64.HasDARN}, + {"scv", &PPC64.HasSCV}, + + // These capabilities should always be enabled on ppc64 and ppc64le: + // {"vmx", &PPC64.HasVMX}, + // {"dfp", &PPC64.HasDFP}, + // {"vsx", &PPC64.HasVSX}, + // {"isel", &PPC64.HasISEL}, + // {"vcrypto", &PPC64.HasVCRYPTO}, + } + + // HWCAP feature bits + PPC64.HasVMX = isSet(hwcap, _PPC_FEATURE_HAS_ALTIVEC) + PPC64.HasDFP = isSet(hwcap, _PPC_FEATURE_HAS_DFP) + PPC64.HasVSX = isSet(hwcap, _PPC_FEATURE_HAS_VSX) + + // HWCAP2 feature bits + PPC64.IsPOWER8 = isSet(hwcap2, _PPC_FEATURE2_ARCH_2_07) + PPC64.HasHTM = isSet(hwcap2, _PPC_FEATURE2_HAS_HTM) + PPC64.HasISEL = isSet(hwcap2, _PPC_FEATURE2_HAS_ISEL) + PPC64.HasVCRYPTO = isSet(hwcap2, _PPC_FEATURE2_HAS_VEC_CRYPTO) + PPC64.HasHTMNOSC = isSet(hwcap2, _PPC_FEATURE2_HTM_NOSC) + PPC64.IsPOWER9 = isSet(hwcap2, _PPC_FEATURE2_ARCH_3_00) + PPC64.HasDARN = isSet(hwcap2, _PPC_FEATURE2_DARN) + PPC64.HasSCV = isSet(hwcap2, _PPC_FEATURE2_SCV) +} + +func isSet(hwc uint, value uint) bool { + return hwc&value != 0 +} diff --git a/vendor/github.com/templexxx/cpu/cpu_riscv64.go b/vendor/github.com/templexxx/cpu/cpu_riscv64.go new file mode 100644 index 00000000..0b86a4fd --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_riscv64.go @@ -0,0 +1,10 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLineSize = 32 + +func doinit() { +} diff --git a/vendor/github.com/templexxx/cpu/cpu_s390x.go b/vendor/github.com/templexxx/cpu/cpu_s390x.go new file mode 100644 index 00000000..389a058c --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_s390x.go @@ -0,0 +1,153 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLineSize = 256 + +// bitIsSet reports whether the bit at index is set. The bit index +// is in big endian order, so bit index 0 is the leftmost bit. +func bitIsSet(bits []uint64, index uint) bool { + return bits[index/64]&((1<<63)>>(index%64)) != 0 +} + +// function is the function code for the named function. +type function uint8 + +const ( + // KM{,A,C,CTR} function codes + aes128 function = 18 // AES-128 + aes192 = 19 // AES-192 + aes256 = 20 // AES-256 + + // K{I,L}MD function codes + sha1 = 1 // SHA-1 + sha256 = 2 // SHA-256 + sha512 = 3 // SHA-512 + + // KLMD function codes + ghash = 65 // GHASH +) + +// queryResult contains the result of a Query function +// call. Bits are numbered in big endian order so the +// leftmost bit (the MSB) is at index 0. +type queryResult struct { + bits [2]uint64 +} + +// Has reports whether the given functions are present. 
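+// For example (illustrative), kmQuery().Has(aes128, aes192, aes256) reports
+// whether the KM cipher-message instruction implements all three AES key
+// sizes; doinit below uses exactly this pattern.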
+func (q *queryResult) Has(fns ...function) bool { + if len(fns) == 0 { + panic("no function codes provided") + } + for _, f := range fns { + if !bitIsSet(q.bits[:], uint(f)) { + return false + } + } + return true +} + +// facility is a bit index for the named facility. +type facility uint8 + +const ( + // mandatory facilities + zarch facility = 1 // z architecture mode is active + stflef = 7 // store-facility-list-extended + ldisp = 18 // long-displacement + eimm = 21 // extended-immediate + + // miscellaneous facilities + dfp = 42 // decimal-floating-point + etf3eh = 30 // extended-translation 3 enhancement + + // cryptography facilities + msa = 17 // message-security-assist + msa3 = 76 // message-security-assist extension 3 + msa4 = 77 // message-security-assist extension 4 + msa5 = 57 // message-security-assist extension 5 + msa8 = 146 // message-security-assist extension 8 + + // Note: vx and highgprs are excluded because they require + // kernel support and so must be fetched from HWCAP. +) + +// facilityList contains the result of an STFLE call. +// Bits are numbered in big endian order so the +// leftmost bit (the MSB) is at index 0. +type facilityList struct { + bits [4]uint64 +} + +// Has reports whether the given facilities are present. +func (s *facilityList) Has(fs ...facility) bool { + if len(fs) == 0 { + panic("no facility bits provided") + } + for _, f := range fs { + if !bitIsSet(s.bits[:], uint(f)) { + return false + } + } + return true +} + +// The following feature detection functions are defined in cpu_s390x.s. +// They are likely to be expensive to call so the results should be cached. +func stfle() facilityList +func kmQuery() queryResult +func kmcQuery() queryResult +func kmctrQuery() queryResult +func kmaQuery() queryResult +func kimdQuery() queryResult +func klmdQuery() queryResult + +func doinit() { + options = []option{ + {"zarch", &S390X.HasZArch}, + {"stfle", &S390X.HasSTFLE}, + {"ldisp", &S390X.HasLDisp}, + {"msa", &S390X.HasMSA}, + {"eimm", &S390X.HasEImm}, + {"dfp", &S390X.HasDFP}, + {"etf3eh", &S390X.HasETF3Enhanced}, + {"vx", &S390X.HasVX}, + } + + aes := []function{aes128, aes192, aes256} + facilities := stfle() + + S390X.HasZArch = facilities.Has(zarch) + S390X.HasSTFLE = facilities.Has(stflef) + S390X.HasLDisp = facilities.Has(ldisp) + S390X.HasEImm = facilities.Has(eimm) + S390X.HasDFP = facilities.Has(dfp) + S390X.HasETF3Enhanced = facilities.Has(etf3eh) + S390X.HasMSA = facilities.Has(msa) + + if S390X.HasMSA { + // cipher message + km, kmc := kmQuery(), kmcQuery() + S390X.HasAES = km.Has(aes...) + S390X.HasAESCBC = kmc.Has(aes...) + if facilities.Has(msa4) { + kmctr := kmctrQuery() + S390X.HasAESCTR = kmctr.Has(aes...) + } + if facilities.Has(msa8) { + kma := kmaQuery() + S390X.HasAESGCM = kma.Has(aes...) + } + + // compute message digest + kimd := kimdQuery() // intermediate (no padding) + klmd := klmdQuery() // last (padding) + S390X.HasSHA1 = kimd.Has(sha1) && klmd.Has(sha1) + S390X.HasSHA256 = kimd.Has(sha256) && klmd.Has(sha256) + S390X.HasSHA512 = kimd.Has(sha512) && klmd.Has(sha512) + S390X.HasGHASH = kimd.Has(ghash) // KLMD-GHASH does not exist + } +} diff --git a/vendor/github.com/templexxx/cpu/cpu_s390x.s b/vendor/github.com/templexxx/cpu/cpu_s390x.s new file mode 100644 index 00000000..9678035f --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_s390x.s @@ -0,0 +1,55 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +#include "textflag.h" + +// func stfle() facilityList +TEXT ·stfle(SB), NOSPLIT|NOFRAME, $0-32 + MOVD $ret+0(FP), R1 + MOVD $3, R0 // last doubleword index to store + XC $32, (R1), (R1) // clear 4 doublewords (32 bytes) + WORD $0xb2b01000 // store facility list extended (STFLE) + RET + +// func kmQuery() queryResult +TEXT ·kmQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KM-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xB92E0024 // cipher message (KM) + RET + +// func kmcQuery() queryResult +TEXT ·kmcQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KMC-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xB92F0024 // cipher message with chaining (KMC) + RET + +// func kmctrQuery() queryResult +TEXT ·kmctrQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KMCTR-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xB92D4024 // cipher message with counter (KMCTR) + RET + +// func kmaQuery() queryResult +TEXT ·kmaQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KMA-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xb9296024 // cipher message with authentication (KMA) + RET + +// func kimdQuery() queryResult +TEXT ·kimdQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KIMD-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xB93E0024 // compute intermediate message digest (KIMD) + RET + +// func klmdQuery() queryResult +TEXT ·klmdQuery(SB), NOSPLIT|NOFRAME, $0-16 + MOVD $0, R0 // set function code to 0 (KLMD-Query) + MOVD $ret+0(FP), R1 // address of 16-byte return value + WORD $0xB93F0024 // compute last message digest (KLMD) + RET diff --git a/vendor/github.com/templexxx/cpu/cpu_wasm.go b/vendor/github.com/templexxx/cpu/cpu_wasm.go new file mode 100644 index 00000000..1107a7ad --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_wasm.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const CacheLineSize = 64 diff --git a/vendor/github.com/templexxx/cpu/cpu_x86.go b/vendor/github.com/templexxx/cpu/cpu_x86.go new file mode 100644 index 00000000..7297fe2d --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_x86.go @@ -0,0 +1,433 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 amd64p32 + +package cpu + +import ( + "fmt" + "strings" +) + +const CacheLineSize = 64 + +// cpuid is implemented in cpu_x86.s. +func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32) + +// xgetbv with ecx = 0 is implemented in cpu_x86.s. 
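+// xgetbv returns XCR0, whose bits report which register states the OS will
+// save and restore: bit 1 = SSE/XMM, bit 2 = AVX/YMM, and bits 5-7 = the
+// AVX-512 opmask and ZMM states. The osSupportsAVX/osSupportsAVX512 checks
+// in doinit below test these bits.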
+func xgetbv() (eax, edx uint32)
+
+const (
+	// edx bits
+	cpuid_SSE2 = 1 << 26
+
+	// ecx bits
+	cpuid_SSE3       = 1 << 0
+	cpuid_PCLMULQDQ  = 1 << 1
+	cpuid_SSSE3      = 1 << 9
+	cpuid_FMA        = 1 << 12
+	cpuid_SSE41      = 1 << 19
+	cpuid_SSE42      = 1 << 20
+	cpuid_POPCNT     = 1 << 23
+	cpuid_AES        = 1 << 25
+	cpuid_OSXSAVE    = 1 << 27
+	cpuid_AVX        = 1 << 28
+	cpuid_CMPXCHG16B = 1 << 13
+
+	// ebx bits
+	cpuid_BMI1     = 1 << 3
+	cpuid_AVX2     = 1 << 5
+	cpuid_BMI2     = 1 << 8
+	cpuid_ERMS     = 1 << 9
+	cpuid_ADX      = 1 << 19
+	cpuid_AVX512F  = 1 << 16
+	cpuid_AVX512DQ = 1 << 17
+	cpuid_AVX512BW = 1 << 30
+	cpuid_AVX512VL = 1 << 31
+
+	// edx bits (CPUID 0x80000007)
+	cpuid_Invariant_TSC = 1 << 8
+)
+
+func doinit() {
+	options = []option{
+		{"adx", &X86.HasADX},
+		{"aes", &X86.HasAES},
+		{"avx", &X86.HasAVX},
+		{"avx2", &X86.HasAVX2},
+		{"bmi1", &X86.HasBMI1},
+		{"bmi2", &X86.HasBMI2},
+		{"erms", &X86.HasERMS},
+		{"fma", &X86.HasFMA},
+		{"pclmulqdq", &X86.HasPCLMULQDQ},
+		{"popcnt", &X86.HasPOPCNT},
+		{"sse3", &X86.HasSSE3},
+		{"sse41", &X86.HasSSE41},
+		{"sse42", &X86.HasSSE42},
+		{"ssse3", &X86.HasSSSE3},
+		{"avx512f", &X86.HasAVX512F},
+		{"avx512dq", &X86.HasAVX512DQ},
+		{"avx512bw", &X86.HasAVX512BW},
+		{"avx512vl", &X86.HasAVX512VL},
+		{"invariant_tsc", &X86.HasInvariantTSC},
+
+		// sse2 set as last element so it can easily be removed again. See code below.
+		{"sse2", &X86.HasSSE2},
+	}
+
+	// Remove sse2 from options on amd64(p32) because SSE2 is a mandatory feature for these GOARCHs.
+	if GOARCH == "amd64" || GOARCH == "amd64p32" {
+		options = options[:len(options)-1]
+	}
+
+	maxID, _, _, _ := cpuid(0, 0)
+
+	if maxID < 1 {
+		return
+	}
+
+	_, _, ecx1, edx1 := cpuid(1, 0)
+	X86.HasSSE2 = isSet(edx1, cpuid_SSE2)
+
+	X86.HasSSE3 = isSet(ecx1, cpuid_SSE3)
+	X86.HasPCLMULQDQ = isSet(ecx1, cpuid_PCLMULQDQ)
+	X86.HasSSSE3 = isSet(ecx1, cpuid_SSSE3)
+	X86.HasFMA = isSet(ecx1, cpuid_FMA)
+	X86.HasSSE41 = isSet(ecx1, cpuid_SSE41)
+	X86.HasSSE42 = isSet(ecx1, cpuid_SSE42)
+	X86.HasPOPCNT = isSet(ecx1, cpuid_POPCNT)
+	X86.HasAES = isSet(ecx1, cpuid_AES)
+	X86.HasCMPXCHG16B = isSet(ecx1, cpuid_CMPXCHG16B)
+	X86.HasOSXSAVE = isSet(ecx1, cpuid_OSXSAVE)
+
+	osSupportsAVX := false
+	osSupportsAVX512 := false
+	// For XGETBV, OSXSAVE bit is required and sufficient.
+	if X86.HasOSXSAVE {
+		eax, _ := xgetbv()
+		// Check if XMM and YMM registers have OS support.
+		osSupportsAVX = isSet(eax, 1<<1) && isSet(eax, 1<<2)
+		// Check if ZMM registers have OS support.
+		osSupportsAVX512 = isSet(eax>>5, 7) && isSet(eax>>1, 3)
+	}
+
+	X86.HasAVX = isSet(ecx1, cpuid_AVX) && osSupportsAVX
+
+	if maxID < 7 {
+		return
+	}
+
+	_, ebx7, _, _ := cpuid(7, 0)
+	X86.HasBMI1 = isSet(ebx7, cpuid_BMI1)
+	X86.HasAVX2 = isSet(ebx7, cpuid_AVX2) && osSupportsAVX
+	X86.HasAVX512F = isSet(ebx7, cpuid_AVX512F) && osSupportsAVX512
+	X86.HasAVX512DQ = isSet(ebx7, cpuid_AVX512DQ) && osSupportsAVX512
+	X86.HasAVX512BW = isSet(ebx7, cpuid_AVX512BW) && osSupportsAVX512
+	X86.HasAVX512VL = isSet(ebx7, cpuid_AVX512VL) && osSupportsAVX512
+	X86.HasBMI2 = isSet(ebx7, cpuid_BMI2)
+	X86.HasERMS = isSet(ebx7, cpuid_ERMS)
+	X86.HasADX = isSet(ebx7, cpuid_ADX)
+
+	X86.Cache = getCacheSize()
+
+	X86.HasInvariantTSC = hasInvariantTSC()
+
+	X86.Family, X86.Model, X86.SteppingID = getVersionInfo()
+
+	X86.Signature = makeSignature(X86.Family, X86.Model)
+
+	X86.Name = getName()
+
+	X86.TSCFrequency = getNativeTSCFrequency(X86.Name, X86.Signature, X86.SteppingID)
+}
+
+func isSet(hwc uint32, value uint32) bool {
+	return hwc&value != 0
+}
+
+func hasInvariantTSC() bool {
+	if maxExtendedFunction() < 0x80000007 {
+		return false
+	}
+	_, _, _, edx := cpuid(0x80000007, 0)
+	return isSet(edx, cpuid_Invariant_TSC)
+}
+
+func getName() string {
+	if maxExtendedFunction() >= 0x80000004 {
+		v := make([]uint32, 0, 48)
+		for i := uint32(0); i < 3; i++ {
+			a, b, c, d := cpuid(0x80000002+i, 0)
+			v = append(v, a, b, c, d)
+		}
+		return strings.Trim(string(valAsString(v...)), " ")
+	}
+	return "unknown"
+}
+
+// getNativeTSCFrequency gets the TSC frequency from CPUID.
+// It only supports Intel (Skylake or later microarchitectures); the key
+// information comes from the Intel manual and kernel code
+// (especially this commit: https://github.com/torvalds/linux/commit/604dc9170f2435d27da5039a3efd757dceadc684).
+func getNativeTSCFrequency(name, sign string, steppingID uint32) uint64 {
+
+	if vendorID() != Intel {
+		return 0
+	}
+
+	if maxFunctionID() < 0x15 {
+		return 0
+	}
+
+	// ApolloLake, GeminiLake, CannonLake (and presumably all new chipsets
+	// from this point) report the crystal frequency directly via CPUID.0x15.
+	// That's definitive data that we can rely upon.
+	eax, ebx, ecx, _ := cpuid(0x15, 0)
+
+	// If ebx is 0, the TSC/”core crystal clock” ratio is not enumerated.
+	// We won't provide TSC frequency detection in this situation.
+	if eax == 0 || ebx == 0 {
+		return 0
+	}
+
+	// Skylake, Kabylake and all variants of those two chipsets report a
+	// crystal frequency of zero.
+	if ecx == 0 { // Crystal clock frequency is not enumerated.
+		ecx = getCrystalClockFrequency(sign, steppingID)
+	}
+
+	// TSC frequency = “core crystal clock frequency” * EBX/EAX.
+	// Multiply before dividing to avoid truncating the EBX/EAX ratio.
+	return uint64(ecx) * uint64(ebx) / uint64(eax)
+}
+
+// Copied from: CPUID Signature values of DisplayFamily and DisplayModel,
+// in Intel® 64 and IA-32 Architectures Software Developer’s Manual
+// Volume 4: Model-Specific Registers
+// & https://github.com/torvalds/linux/blob/master/arch/x86/include/asm/intel-family.h
+const (
+	IntelFam6SkylakeL     = "06_4EH"
+	IntelFam6Skylake      = "06_5EH"
+	IntelFam6XeonScalable = "06_55H"
+	IntelFam6KabylakeL    = "06_8EH"
+	IntelFam6Kabylake     = "06_9EH"
+)
+
+// getCrystalClockFrequency gets the crystal clock frequency for Intel
+// processors on which CPUID.15H.EBX[31:0] ÷ CPUID.15H.EAX[31:0] is enumerated
+// but CPUID.15H.ECX is not; it returns the nominal core crystal clock frequency.
+//
+// The crystal clock frequencies provided by Intel's hardcoded tables are not
+// entirely accurate in some cases, e.g.
SkyLake server CPUs may have an issue (all SKX subject the crystal to an EMI reduction circuit that
+// reduces its actual frequency by approximately 0.25%);
+// see https://lore.kernel.org/lkml/ff6dcea166e8ff8f2f6a03c17beab2cb436aa779.1513920414.git.len.brown@intel.com/
+// for more details.
+// Based on this report, I set a coefficient (0.9975) for IntelFam6XeonScalable (SKX).
+//
+// Unlike the kernel way (mentioned in https://github.com/torvalds/linux/commit/604dc9170f2435d27da5039a3efd757dceadc684),
+// I prefer the Intel hardcoded tables (in
+// 18.7.3 Determining the Processor Base Frequency, Table 18-85. Nominal Core Crystal Clock Frequency),
+// because after some testing (comparing with the wall clock; see https://github.com/templexxx/tsc/tsc_test.go for more details)
+// I found the hardcoded tables to be more accurate.
+func getCrystalClockFrequency(sign string, steppingID uint32) uint32 {
+
+	if maxFunctionID() < 0x16 {
+		return 0
+	}
+
+	switch sign {
+	case IntelFam6SkylakeL:
+		return 24 * 1000 * 1000
+	case IntelFam6Skylake:
+		return 24 * 1000 * 1000
+	case IntelFam6XeonScalable:
+		// SKL-SP.
+		// see: https://community.intel.com/t5/Software-Tuning-Performance/How-to-detect-microarchitecture-on-Xeon-Scalable/m-p/1205162#M7633.
+		if steppingID == 0x2 || steppingID == 0x3 || steppingID == 0x4 {
+			return 25 * 1000 * 1000 * 0.9975
+		}
+		return 25 * 1000 * 1000 // TODO: check that other Xeon Scalable steppings have no slowdown issue.
+	case IntelFam6KabylakeL:
+		return 24 * 1000 * 1000
+	case IntelFam6Kabylake:
+		return 24 * 1000 * 1000
+	}
+
+	return 0
+}
+
+func getVersionInfo() (uint32, uint32, uint32) {
+	if maxFunctionID() < 0x1 {
+		return 0, 0, 0
+	}
+	eax, _, _, _ := cpuid(1, 0)
+	family := (eax >> 8) & 0xf
+	displayFamily := family
+	if family == 0xf {
+		displayFamily = ((eax >> 20) & 0xff) + family
+	}
+	model := (eax >> 4) & 0xf
+	displayModel := model
+	if family == 0x6 || family == 0xf {
+		displayModel = ((eax >> 12) & 0xf0) + model
+	}
+	return displayFamily, displayModel, eax & 0xf // stepping ID is EAX[3:0]
+}
+
+// signature format: XX_XXH
+func makeSignature(family, model uint32) string {
+	signature := strings.ToUpper(fmt.Sprintf("0%x_0%xH", family, model))
+	ss := strings.Split(signature, "_")
+	for i, s := range ss {
+		// We may have inserted one `0` too many; drop it.
+		if len(s) > 2 {
+			s = s[1:]
+			ss[i] = s
+		}
+	}
+	return strings.Join(ss, "_")
+}
+
+// getCacheSize is from
+// https://github.com/klauspost/cpuid/blob/5a626f7029c910cc8329dae5405ee4f65034bce5/cpuid.go#L723
+func getCacheSize() Cache {
+	c := Cache{
+		L1I: -1,
+		L1D: -1,
+		L2:  -1,
+		L3:  -1,
+	}
+
+	vendor := vendorID()
+	switch vendor {
+	case Intel:
+		if maxFunctionID() < 4 {
+			return c
+		}
+		for i := uint32(0); ; i++ {
+			eax, ebx, ecx, _ := cpuid(4, i)
+			cacheType := eax & 15
+			if cacheType == 0 {
+				break
+			}
+			cacheLevel := (eax >> 5) & 7
+			coherency := int(ebx&0xfff) + 1
+			partitions := int((ebx>>12)&0x3ff) + 1
+			associativity := int((ebx>>22)&0x3ff) + 1
+			sets := int(ecx) + 1
+			size := associativity * partitions * coherency * sets
+			switch cacheLevel {
+			case 1:
+				if cacheType == 1 {
+					// 1 = Data Cache
+					c.L1D = size
+				} else if cacheType == 2 {
+					// 2 = Instruction Cache
+					c.L1I = size
+				} else {
+					if c.L1D < 0 {
+						c.L1D = size
+					}
+					if c.L1I < 0 {
+						c.L1I = size
+					}
+				}
+			case 2:
+				c.L2 = size
+			case 3:
+				c.L3 = size
+			}
+		}
+	case AMD, Hygon:
+		// Untested.
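+		// (For reference: leaf 0x80000005 reports the L1 sizes in KB in
+		// ECX[31:24] (data) and EDX[31:24] (instruction), and leaf 0x80000006
+		// reports the L2 size in KB in ECX[31:16]; the code below decodes
+		// exactly those fields.)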
+ if maxExtendedFunction() < 0x80000005 { + return c + } + _, _, ecx, edx := cpuid(0x80000005, 0) + c.L1D = int(((ecx >> 24) & 0xFF) * 1024) + c.L1I = int(((edx >> 24) & 0xFF) * 1024) + + if maxExtendedFunction() < 0x80000006 { + return c + } + _, _, ecx, _ = cpuid(0x80000006, 0) + c.L2 = int(((ecx >> 16) & 0xFFFF) * 1024) + } + + return c +} + +func maxFunctionID() uint32 { + a, _, _, _ := cpuid(0, 0) + return a +} + +func maxExtendedFunction() uint32 { + eax, _, _, _ := cpuid(0x80000000, 0) + return eax +} + +const ( + Other = iota + Intel + AMD + VIA + Transmeta + NSC + KVM // Kernel-based Virtual Machine + MSVM // Microsoft Hyper-V or Windows Virtual PC + VMware + XenHVM + Bhyve + Hygon +) + +// Except from http://en.wikipedia.org/wiki/CPUID#EAX.3D0:_Get_vendor_ID +var vendorMapping = map[string]int{ + "AMDisbetter!": AMD, + "AuthenticAMD": AMD, + "CentaurHauls": VIA, + "GenuineIntel": Intel, + "TransmetaCPU": Transmeta, + "GenuineTMx86": Transmeta, + "Geode by NSC": NSC, + "VIA VIA VIA ": VIA, + "KVMKVMKVMKVM": KVM, + "Microsoft Hv": MSVM, + "VMwareVMware": VMware, + "XenVMMXenVMM": XenHVM, + "bhyve bhyve ": Bhyve, + "HygonGenuine": Hygon, +} + +func vendorID() int { + _, b, c, d := cpuid(0, 0) + v := valAsString(b, d, c) + vend, ok := vendorMapping[string(v)] + if !ok { + return Other + } + return vend +} + +func valAsString(values ...uint32) []byte { + r := make([]byte, 4*len(values)) + for i, v := range values { + dst := r[i*4:] + dst[0] = byte(v & 0xff) + dst[1] = byte((v >> 8) & 0xff) + dst[2] = byte((v >> 16) & 0xff) + dst[3] = byte((v >> 24) & 0xff) + switch { + case dst[0] == 0: + return r[:i*4] + case dst[1] == 0: + return r[:i*4+1] + case dst[2] == 0: + return r[:i*4+2] + case dst[3] == 0: + return r[:i*4+3] + } + } + return r +} diff --git a/vendor/github.com/templexxx/cpu/cpu_x86.s b/vendor/github.com/templexxx/cpu/cpu_x86.s new file mode 100644 index 00000000..228fbcf6 --- /dev/null +++ b/vendor/github.com/templexxx/cpu/cpu_x86.s @@ -0,0 +1,32 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 amd64p32 + +#include "textflag.h" + +// func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32) +TEXT ·cpuid(SB), NOSPLIT, $0-24 + MOVL eaxArg+0(FP), AX + MOVL ecxArg+4(FP), CX + CPUID + MOVL AX, eax+8(FP) + MOVL BX, ebx+12(FP) + MOVL CX, ecx+16(FP) + MOVL DX, edx+20(FP) + RET + +// func xgetbv() (eax, edx uint32) +TEXT ·xgetbv(SB),NOSPLIT,$0-8 +#ifdef GOOS_nacl + // nacl does not support XGETBV. 
+ MOVL $0, eax+0(FP) + MOVL $0, edx+4(FP) +#else + MOVL $0, CX + XGETBV + MOVL AX, eax+0(FP) + MOVL DX, edx+4(FP) +#endif + RET diff --git a/vendor/github.com/templexxx/xorsimd/.gitattributes b/vendor/github.com/templexxx/xorsimd/.gitattributes new file mode 100644 index 00000000..68f7d043 --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/.gitattributes @@ -0,0 +1 @@ +*.s linguist-language=go:x diff --git a/vendor/github.com/templexxx/xorsimd/.github/workflows/unit-test.yml b/vendor/github.com/templexxx/xorsimd/.github/workflows/unit-test.yml new file mode 100644 index 00000000..1f8d8850 --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/.github/workflows/unit-test.yml @@ -0,0 +1,36 @@ +name: unit-test + +on: + push: + branches: + - master + - release/* + pull_request: + branches: + - master + +jobs: + + test: + name: Test + runs-on: ubuntu-latest + steps: + + - name: Set up Go 1.13 + uses: actions/setup-go@v1 + with: + go-version: 1.13 + id: go + + - name: Check out code into the Go module directory + uses: actions/checkout@v1 + + - name: Get dependencies + run: | + go get -v -t -d ./... + if [ -f Gopkg.toml ]; then + curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh + dep ensure + fi + - name: Run test + run: CGO_ENABLED=1 GO111MODULE=on go test -v -race diff --git a/vendor/github.com/templexxx/xorsimd/.gitignore b/vendor/github.com/templexxx/xorsimd/.gitignore new file mode 100644 index 00000000..43309f8b --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/.gitignore @@ -0,0 +1,13 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out +.idea diff --git a/vendor/github.com/templexxx/xorsimd/LICENSE b/vendor/github.com/templexxx/xorsimd/LICENSE new file mode 100644 index 00000000..08ee7141 --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Temple3x (temple3x@gmail.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/vendor/github.com/templexxx/xorsimd/README.md b/vendor/github.com/templexxx/xorsimd/README.md new file mode 100644 index 00000000..9dce5c9c --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/README.md @@ -0,0 +1,46 @@
+# XOR SIMD
+
+[![GoDoc][1]][2] [![MIT licensed][3]][4] [![Build Status][5]][6] [![Go Report Card][7]][8] [![Sourcegraph][9]][10]
+
+[1]: https://godoc.org/github.com/templexxx/xorsimd?status.svg
+[2]: https://godoc.org/github.com/templexxx/xorsimd
+[3]: https://img.shields.io/badge/license-MIT-blue.svg
+[4]: LICENSE
+[5]: https://github.com/templexxx/xorsimd/workflows/unit-test/badge.svg
+[6]: https://github.com/templexxx/xorsimd
+[7]: https://goreportcard.com/badge/github.com/templexxx/xorsimd
+[8]: https://goreportcard.com/report/github.com/templexxx/xorsimd
+[9]: https://sourcegraph.com/github.com/templexxx/xorsimd/-/badge.svg
+[10]: https://sourcegraph.com/github.com/templexxx/xorsimd?badge
+
+## Introduction:
+
+>- XOR code engine in pure Go.
+>
+>- [High Performance](https://github.com/templexxx/xorsimd#performance): More than 270 GB/s per physical core.
+
+## Performance
+
+Performance depends mainly on:
+
+>- CPU instruction extension.
+>
+>- Number of source row vectors.
+
+**Platform:**
+
+*AWS c5d.xlarge (Intel(R) Xeon(R) Platinum 8124M CPU @ 3.00GHz)*
+
+**All tests run on a single core.**
+
+`I/O = (src_num + 1) * vector_size / cost`
+
+| Src Num | Vector size | AVX512 I/O (MB/S) | AVX2 I/O (MB/S) | SSE2 I/O (MB/S) |
+|-------|-------------|-------------|---------------|---------------|
+|5|4KB| 270403.73 | 142825.25 | 74443.91 |
+|5|1MB| 26948.34 | 26887.37 | 26950.65 |
+|5|8MB| 17881.32 | 17212.56 | 16402.97 |
+|10|4KB| 190445.30 | 102953.59 | 53244.04 |
+|10|1MB| 26424.44 | 26618.65 | 26094.39 |
+|10|8MB| 15471.31 | 14866.72 | 13565.80 |
diff --git a/vendor/github.com/templexxx/xorsimd/go.mod b/vendor/github.com/templexxx/xorsimd/go.mod new file mode 100644 index 00000000..ac5f57fc --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/go.mod @@ -0,0 +1,5 @@
+module github.com/templexxx/xorsimd
+
+require github.com/templexxx/cpu v0.0.1
+
+go 1.13
diff --git a/vendor/github.com/templexxx/xorsimd/go.sum b/vendor/github.com/templexxx/xorsimd/go.sum new file mode 100644 index 00000000..04d04de8 --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/go.sum @@ -0,0 +1,2 @@
+github.com/templexxx/cpu v0.0.1 h1:hY4WdLOgKdc8y13EYklu9OUTXik80BkxHoWvTO6MQQY=
+github.com/templexxx/cpu v0.0.1/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk=
diff --git a/vendor/github.com/templexxx/xorsimd/xor.go b/vendor/github.com/templexxx/xorsimd/xor.go new file mode 100644 index 00000000..ae88911d --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/xor.go @@ -0,0 +1,89 @@
+// Copyright (c) 2019. Temple3x (temple3x@gmail.com)
+//
+// Use of this source code is governed by the MIT License
+// that can be found in the LICENSE file.
+
+package xorsimd
+
+import "github.com/templexxx/cpu"
+
+// EnableAVX512 may slow down the CPU clock (or maybe not).
+// TODO: needs more research, see:
+// https://lemire.me/blog/2018/04/19/by-how-much-does-avx-512-slow-down-your-cpu-a-first-experiment/
+var EnableAVX512 = true
+
+// cpuFeature indicates which instruction set will be used.
+var cpuFeature = getCPUFeature()
+
+const (
+	avx512 = iota
+	avx2
+	sse2
+	generic
+)
+
+// TODO: Add ARM feature...
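+// A minimal usage sketch (not part of the upstream source; assumes the
+// package is imported as xorsimd "github.com/templexxx/xorsimd" and that a
+// and b are byte slices):
+//
+//	dst := make([]byte, len(a))
+//	n := xorsimd.Encode(dst, [][]byte{a, b})
+//	// dst[i] == a[i]^b[i] for i < n; n is the shortest length involved
+//
+// Bytes(dst, a, b) below is the two-source convenience wrapper around Encode.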
+func getCPUFeature() int {
+	if hasAVX512() && EnableAVX512 {
+		return avx512
+	} else if cpu.X86.HasAVX2 {
+		return avx2
+	} else {
+		return sse2 // amd64 must have SSE2
+	}
+}
+
+func hasAVX512() (ok bool) {
+
+	return cpu.X86.HasAVX512VL &&
+		cpu.X86.HasAVX512BW &&
+		cpu.X86.HasAVX512F &&
+		cpu.X86.HasAVX512DQ
+}
+
+// Encode XORs the elements of the source slices into the
+// destination slice. The source and destination may overlap.
+// Encode returns the number of bytes encoded, which will be the minimum of
+// len(src[i]) and len(dst).
+func Encode(dst []byte, src [][]byte) (n int) {
+	n = checkLen(dst, src)
+	if n == 0 {
+		return
+	}
+
+	dst = dst[:n]
+	for i := range src {
+		src[i] = src[i][:n]
+	}
+
+	if len(src) == 1 {
+		copy(dst, src[0])
+		return
+	}
+
+	encode(dst, src)
+	return
+}
+
+func checkLen(dst []byte, src [][]byte) int {
+	n := len(dst)
+	for i := range src {
+		if len(src[i]) < n {
+			n = len(src[i])
+		}
+	}
+
+	if n <= 0 {
+		return 0
+	}
+	return n
+}
+
+// Bytes XORs the bytes in a and b into a
+// destination slice. The source and destination may overlap.
+//
+// Bytes returns the number of bytes encoded, which will be the minimum of
+// len(dst), len(a), len(b).
+func Bytes(dst, a, b []byte) int {
+	return Encode(dst, [][]byte{a, b})
+}
diff --git a/vendor/github.com/templexxx/xorsimd/xor_amd64.go b/vendor/github.com/templexxx/xorsimd/xor_amd64.go new file mode 100644 index 00000000..5d46df35 --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/xor_amd64.go @@ -0,0 +1,95 @@
+// Copyright (c) 2019. Temple3x (temple3x@gmail.com)
+//
+// Use of this source code is governed by the MIT License
+// that can be found in the LICENSE file.
+
+package xorsimd
+
+func encode(dst []byte, src [][]byte) {
+
+	switch cpuFeature {
+	case avx512:
+		encodeAVX512(dst, src)
+	case avx2:
+		encodeAVX2(dst, src)
+	default:
+		encodeSSE2(dst, src)
+	}
+	return
+}
+
+// Bytes8 XORs 8 bytes.
+// The slice arguments dst, a and b are assumed to have length at least 8;
+// if not, Bytes8 will panic.
+func Bytes8(dst, a, b []byte) {
+
+	bytes8(&dst[0], &a[0], &b[0])
+}
+
+// Bytes16 XORs 16 packed bytes.
+// The slice arguments dst, a and b are assumed to have length at least 16;
+// if not, Bytes16 will panic.
+func Bytes16(dst, a, b []byte) {
+
+	bytes16(&dst[0], &a[0], &b[0])
+}
+
+// Bytes8Align XORs 8 bytes.
+// The slice arguments dst, a and b are assumed to have length at least 8;
+// if not, Bytes8Align will panic.
+func Bytes8Align(dst, a, b []byte) {
+
+	bytes8(&dst[0], &a[0], &b[0])
+}
+
+// Bytes16Align XORs 16 packed bytes.
+// The slice arguments dst, a and b are assumed to have length at least 16;
+// if not, Bytes16Align will panic.
+func Bytes16Align(dst, a, b []byte) {
+
+	bytes16(&dst[0], &a[0], &b[0])
+}
+
+// BytesA XORs the len(a) bytes in a and b into a
+// destination slice.
+// The destination should have enough space.
+//
+// It's used for encoding small byte slices (a few dozen bytes or less),
+// and the slices may not be aligned to 8 bytes or 16 bytes.
+// If the length is large, it's better to use 'func Bytes(dst, a, b []byte)' instead
+// to get better performance.
+func BytesA(dst, a, b []byte) {
+
+	bytesN(&dst[0], &a[0], &b[0], len(a))
+}
+
+// BytesB XORs the len(b) bytes in a and b into a
+// destination slice.
+// The destination should have enough space.
+//
+// It's used for encoding small byte slices (a few dozen bytes or less),
+// and the slices may not be aligned to 8 bytes or 16 bytes.
+// If the length is large, it's better to use 'func Bytes(dst, a, b []byte)' instead
+// to get better performance.
+func BytesB(dst, a, b []byte) {
+
+	bytesN(&dst[0], &a[0], &b[0], len(b))
+}
+
+//go:noescape
+func encodeAVX512(dst []byte, src [][]byte)
+
+//go:noescape
+func encodeAVX2(dst []byte, src [][]byte)
+
+//go:noescape
+func encodeSSE2(dst []byte, src [][]byte)
+
+//go:noescape
+func bytesN(dst, a, b *byte, n int)
+
+//go:noescape
+func bytes8(dst, a, b *byte)
+
+//go:noescape
+func bytes16(dst, a, b *byte)
diff --git a/vendor/github.com/templexxx/xorsimd/xor_generic.go b/vendor/github.com/templexxx/xorsimd/xor_generic.go new file mode 100644 index 00000000..b12908f8 --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/xor_generic.go @@ -0,0 +1,205 @@
+// Copyright (c) 2019. Temple3x (temple3x@gmail.com)
+//
+// Use of this source code is governed by the MIT License
+// that can be found in the LICENSE file.
+//
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !amd64
+
+package xorsimd
+
+import (
+	"runtime"
+	"unsafe"
+)
+
+const wordSize = int(unsafe.Sizeof(uintptr(0)))
+const supportsUnaligned = runtime.GOARCH == "386" || runtime.GOARCH == "ppc64" || runtime.GOARCH == "ppc64le" || runtime.GOARCH == "s390x"
+
+func encode(dst []byte, src [][]byte) {
+	if supportsUnaligned {
+		fastEncode(dst, src, len(dst))
+	} else {
+		// TODO(hanwen): if (dst, a, b) have common alignment
+		// we could still try fastEncode. It is not clear
+		// how often this happens, and it's only worth it if
+		// the block encryption itself is hardware
+		// accelerated.
+		safeEncode(dst, src, len(dst))
+	}
+
+}
+
+// fastEncode XORs in bulk. It only works on architectures that
+// support unaligned reads/writes.
+func fastEncode(dst []byte, src [][]byte, n int) {
+	w := n / wordSize
+	if w > 0 {
+		wordBytes := w * wordSize
+
+		wordAlignSrc := make([][]byte, len(src))
+		for i := range src {
+			wordAlignSrc[i] = src[i][:wordBytes]
+		}
+		fastEnc(dst[:wordBytes], wordAlignSrc)
+	}
+
+	for i := n - n%wordSize; i < n; i++ {
+		s := src[0][i]
+		for j := 1; j < len(src); j++ {
+			s ^= src[j][i]
+		}
+		dst[i] = s
+	}
+}
+
+func fastEnc(dst []byte, src [][]byte) {
+	dw := *(*[]uintptr)(unsafe.Pointer(&dst))
+	sw := make([][]uintptr, len(src))
+	for i := range src {
+		sw[i] = *(*[]uintptr)(unsafe.Pointer(&src[i]))
+	}
+
+	n := len(dst) / wordSize
+	for i := 0; i < n; i++ {
+		s := sw[0][i]
+		for j := 1; j < len(sw); j++ {
+			s ^= sw[j][i]
+		}
+		dw[i] = s
+	}
+}
+
+func safeEncode(dst []byte, src [][]byte, n int) {
+	for i := 0; i < n; i++ {
+		s := src[0][i]
+		for j := 1; j < len(src); j++ {
+			s ^= src[j][i]
+		}
+		dst[i] = s
+	}
+}
+
+// Bytes8 XORs 8 bytes, word by word.
+// The slice arguments dst, a and b are assumed to have length at least 8;
+// if not, Bytes8 will panic.
+func Bytes8(dst, a, b []byte) {
+
+	bytesWords(dst[:8], a[:8], b[:8])
+}
+
+// Bytes16 XORs 16 packed bytes, word by word.
+// The slice arguments dst, a and b are assumed to have length at least 16;
+// if not, Bytes16 will panic.
+func Bytes16(dst, a, b []byte) {
+
+	bytesWords(dst[:16], a[:16], b[:16])
+}
+
+// bytesWords XORs multiples of 4 or 8 bytes (depending on architecture).
+// The slice arguments a and b are assumed to be of equal length.
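+// (The unsafe conversions in bytesWords reinterpret the byte slices as
+// []uintptr; the slice headers keep their byte lengths, so only the first
+// len/wordSize words are ever indexed, which stays within the backing array.)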
+func bytesWords(dst, a, b []byte) {
+	if supportsUnaligned {
+		dw := *(*[]uintptr)(unsafe.Pointer(&dst))
+		aw := *(*[]uintptr)(unsafe.Pointer(&a))
+		bw := *(*[]uintptr)(unsafe.Pointer(&b))
+		n := len(b) / wordSize
+		for i := 0; i < n; i++ {
+			dw[i] = aw[i] ^ bw[i]
+		}
+	} else {
+		n := len(b)
+		for i := 0; i < n; i++ {
+			dst[i] = a[i] ^ b[i]
+		}
+	}
+}
+
+// Bytes8Align XORs 8 bytes.
+// The slice arguments dst, a and b are assumed to have length at least 8;
+// if not, Bytes8Align will panic.
+//
+// All the byte slices must be aligned to the word size.
+func Bytes8Align(dst, a, b []byte) {
+
+	bytesWordsAlign(dst[:8], a[:8], b[:8])
+}
+
+// Bytes16Align XORs 16 packed bytes.
+// The slice arguments dst, a and b are assumed to have length at least 16;
+// if not, Bytes16Align will panic.
+//
+// All the byte slices must be aligned to the word size.
+func Bytes16Align(dst, a, b []byte) {
+
+	bytesWordsAlign(dst[:16], a[:16], b[:16])
+}
+
+// bytesWordsAlign XORs multiples of 4 or 8 bytes (depending on architecture).
+// The slice arguments a and b are assumed to be of equal length.
+//
+// All the byte slices must be aligned to the word size.
+func bytesWordsAlign(dst, a, b []byte) {
+	dw := *(*[]uintptr)(unsafe.Pointer(&dst))
+	aw := *(*[]uintptr)(unsafe.Pointer(&a))
+	bw := *(*[]uintptr)(unsafe.Pointer(&b))
+	n := len(b) / wordSize
+	for i := 0; i < n; i++ {
+		dw[i] = aw[i] ^ bw[i]
+	}
+}
+
+// BytesA XORs the len(a) bytes in a and b into a
+// destination slice.
+// The destination should have enough space.
+//
+// It's used for encoding small byte slices (a few dozen bytes or less),
+// and the slices may not be aligned to 8 bytes or 16 bytes.
+// If the length is large, it's better to use 'func Bytes(dst, a, b []byte)' instead
+// to get better performance.
+func BytesA(dst, a, b []byte) {
+
+	n := len(a)
+	bytesN(dst[:n], a[:n], b[:n], n)
+}
+
+// BytesB XORs the len(b) bytes in a and b into a
+// destination slice.
+// The destination should have enough space.
+//
+// It's used for encoding small byte slices (a few dozen bytes or less),
+// and the slices may not be aligned to 8 bytes or 16 bytes.
+// If the length is large, it's better to use 'func Bytes(dst, a, b []byte)' instead
+// to get better performance.
+func BytesB(dst, a, b []byte) {
+
+	n := len(b)
+	bytesN(dst[:n], a[:n], b[:n], n)
+}
+
+func bytesN(dst, a, b []byte, n int) {
+
+	switch {
+	case supportsUnaligned:
+		w := n / wordSize
+		if w > 0 {
+			dw := *(*[]uintptr)(unsafe.Pointer(&dst))
+			aw := *(*[]uintptr)(unsafe.Pointer(&a))
+			bw := *(*[]uintptr)(unsafe.Pointer(&b))
+			for i := 0; i < w; i++ {
+				dw[i] = aw[i] ^ bw[i]
+			}
+		}
+
+		for i := (n - n%wordSize); i < n; i++ {
+			dst[i] = a[i] ^ b[i]
+		}
+	default:
+		for i := 0; i < n; i++ {
+			dst[i] = a[i] ^ b[i]
+		}
+	}
+}
diff --git a/vendor/github.com/templexxx/xorsimd/xor_test.go b/vendor/github.com/templexxx/xorsimd/xor_test.go new file mode 100644 index 00000000..fdcbd957 --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/xor_test.go @@ -0,0 +1,480 @@
+// Copyright (c) 2019. Temple3x (temple3x@gmail.com)
+//
+// Use of this source code is governed by the MIT License
+// that can be found in the LICENSE file.
+//
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+// +// TestEncodeBytes is copied from Go Standard lib: +// crypto/cipher/xor_test.go + +package xorsimd + +import ( + "bytes" + "fmt" + "math/rand" + "testing" + "time" + "unsafe" +) + +const ( + kb = 1024 + mb = 1024 * 1024 + + testSize = kb +) + +func TestBytes8(t *testing.T) { + + rand.Seed(time.Now().UnixNano()) + + for j := 0; j < 1024; j++ { + a := make([]byte, 8) + b := make([]byte, 8) + fillRandom(a) + fillRandom(b) + + dst0 := make([]byte, 8) + Bytes8(dst0, a, b) + + dst1 := make([]byte, 8) + for i := 0; i < 8; i++ { + dst1[i] = a[i] ^ b[i] + } + + if !bytes.Equal(dst0, dst1) { + t.Fatal("not equal", a, b, dst0, dst1) + } + } +} + +func TestBytes16(t *testing.T) { + + rand.Seed(time.Now().UnixNano()) + + for j := 0; j < 1024; j++ { + a := make([]byte, 16) + b := make([]byte, 16) + fillRandom(a) + fillRandom(b) + + dst0 := make([]byte, 16) + Bytes16(dst0, a, b) + + dst1 := make([]byte, 16) + for i := 0; i < 16; i++ { + dst1[i] = a[i] ^ b[i] + } + + if !bytes.Equal(dst0, dst1) { + t.Fatal("not equal", dst0, dst1, a, b) + } + } +} + +const wordSize = int(unsafe.Sizeof(uintptr(0))) + +func TestBytes8Align(t *testing.T) { + + rand.Seed(time.Now().UnixNano()) + + for j := 0; j < 1024; j++ { + a := make([]byte, 8+wordSize) + b := make([]byte, 8+wordSize) + dst0 := make([]byte, 8+wordSize) + dst1 := make([]byte, 8+wordSize) + + al := alignment(a) + offset := 0 + if al != 0 { + offset = wordSize - al + } + a = a[offset : offset+8] + + al = alignment(b) + offset = 0 + if al != 0 { + offset = wordSize - al + } + b = b[offset : offset+8] + + al = alignment(dst0) + offset = 0 + if al != 0 { + offset = wordSize - al + } + dst0 = dst0[offset : offset+8] + + al = alignment(dst1) + offset = 0 + if al != 0 { + offset = wordSize - al + } + dst1 = dst1[offset : offset+8] + + fillRandom(a) + fillRandom(b) + + Bytes8Align(dst0, a, b) + + for i := 0; i < 8; i++ { + dst1[i] = a[i] ^ b[i] + } + + if !bytes.Equal(dst0, dst1) { + t.Fatal("not equal", a, b, dst0, dst1) + } + } +} + +func alignment(s []byte) int { + return int(uintptr(unsafe.Pointer(&s[0])) & uintptr(wordSize-1)) +} + +func TestBytes16Align(t *testing.T) { + + rand.Seed(time.Now().UnixNano()) + + for j := 0; j < 1024; j++ { + a := make([]byte, 16+wordSize) + b := make([]byte, 16+wordSize) + dst0 := make([]byte, 16+wordSize) + dst1 := make([]byte, 16+wordSize) + + al := alignment(a) + offset := 0 + if al != 0 { + offset = wordSize - al + } + a = a[offset : offset+16] + + al = alignment(b) + offset = 0 + if al != 0 { + offset = wordSize - al + } + b = b[offset : offset+16] + + al = alignment(dst0) + offset = 0 + if al != 0 { + offset = wordSize - al + } + dst0 = dst0[offset : offset+16] + + al = alignment(dst1) + offset = 0 + if al != 0 { + offset = wordSize - al + } + dst1 = dst1[offset : offset+16] + + fillRandom(a) + fillRandom(b) + + Bytes16Align(dst0, a, b) + + for i := 0; i < 16; i++ { + dst1[i] = a[i] ^ b[i] + } + + if !bytes.Equal(dst0, dst1) { + t.Fatal("not equal", a, b, dst0, dst1) + } + } +} + +func TestBytesA(t *testing.T) { + + rand.Seed(time.Now().UnixNano()) + + for j := 2; j <= 1024; j++ { + + for alignP := 0; alignP < 2; alignP++ { + p := make([]byte, j)[alignP:] + q := make([]byte, j) + d1 := make([]byte, j) + d2 := make([]byte, j) + + fillRandom(p) + fillRandom(q) + + BytesA(d1, p, q) + for i := 0; i < j-alignP; i++ { + d2[i] = p[i] ^ q[i] + } + if !bytes.Equal(d1, d2) { + t.Fatal("not equal") + } + } + } +} + +func TestBytesB(t *testing.T) { + + rand.Seed(time.Now().UnixNano()) + + for j := 2; j <= 1024; j++ { + + for 
alignQ := 0; alignQ < 2; alignQ++ { + p := make([]byte, j) + q := make([]byte, j)[alignQ:] + d1 := make([]byte, j) + d2 := make([]byte, j) + + fillRandom(p) + fillRandom(q) + + BytesB(d1, p, q) + for i := 0; i < j-alignQ; i++ { + d2[i] = p[i] ^ q[i] + } + if !bytes.Equal(d1, d2) { + t.Fatal("not equal") + } + } + } +} + +func TestBytes(t *testing.T) { + + rand.Seed(time.Now().UnixNano()) + + for j := 1; j <= 1024; j++ { + + for alignP := 0; alignP < 2; alignP++ { + for alignQ := 0; alignQ < 2; alignQ++ { + for alignD := 0; alignD < 2; alignD++ { + p := make([]byte, j)[alignP:] + q := make([]byte, j)[alignQ:] + d1 := make([]byte, j)[alignD:] + d2 := make([]byte, j)[alignD:] + + fillRandom(p) + fillRandom(q) + + Bytes(d1, p, q) + n := min(p, q, d1) + for i := 0; i < n; i++ { + d2[i] = p[i] ^ q[i] + } + if !bytes.Equal(d1, d2) { + t.Fatal("not equal") + } + } + } + } + } +} + +func min(a, b, c []byte) int { + n := len(a) + if len(b) < n { + n = len(b) + } + if len(c) < n { + n = len(c) + } + return n +} + +func TestEncodeWithFeature(t *testing.T) { + max := testSize + + switch getCPUFeature() { + case avx512: + testEncode(t, max, sse2, -1) + testEncode(t, max, avx2, sse2) + testEncode(t, max, avx512, avx2) + case avx2: + testEncode(t, max, sse2, -1) + testEncode(t, max, avx2, sse2) + case sse2: + testEncode(t, max, sse2, -1) + case generic: + testEncode(t, max, generic, -1) + } +} + +func testEncode(t *testing.T, maxSize, feat, cmpFeat int) { + + rand.Seed(time.Now().UnixNano()) + srcN := randIntn(10, 2) // Cannot be 1, see func encode(dst []byte, src [][]byte, feature int). + + fs := featToStr(feat) + for size := 1; size <= maxSize; size++ { + exp := make([]byte, size) + src := make([][]byte, srcN) + for j := 0; j < srcN; j++ { + src[j] = make([]byte, size) + fillRandom(src[j]) + } + + if cmpFeat < 0 { + encodeTested(exp, src) + } else { + cpuFeature = cmpFeat + Encode(exp, src) + } + + act := make([]byte, size) + cpuFeature = feat + Encode(act, src) + + if !bytes.Equal(exp, act) { + t.Fatalf("%s mismatched with %s, src_num: %d, size: %d", + fs, featToStr(cmpFeat), srcN, size) + } + } + + t.Logf("%s pass src_num:%d, max_size: %d", + fs, srcN, maxSize) +} + +func featToStr(f int) string { + switch f { + case avx512: + return "AVX512" + case avx2: + return "AVX2" + case sse2: + return "SSE2" + case generic: + return "Generic" + default: + return "Tested" + } +} + +func encodeTested(dst []byte, src [][]byte) { + + n := len(dst) + for i := 0; i < n; i++ { + s := src[0][i] + for j := 1; j < len(src); j++ { + s ^= src[j][i] + } + dst[i] = s + } +} + +// randIntn returns, as an int, a non-negative pseudo-random number in [min,n) +// from the default Source. 
+func randIntn(n, min int) int { + m := rand.Intn(n) + if m < min { + m = min + } + return m +} + +func BenchmarkBytes8(b *testing.B) { + s0 := make([]byte, 8) + s1 := make([]byte, 8) + fillRandom(s0) + fillRandom(s1) + dst0 := make([]byte, 8) + + b.ResetTimer() + b.SetBytes(8) + for i := 0; i < b.N; i++ { + Bytes8(dst0, s0, s1) + } +} + +func BenchmarkBytes16(b *testing.B) { + s0 := make([]byte, 16) + s1 := make([]byte, 16) + fillRandom(s0) + fillRandom(s1) + dst0 := make([]byte, 16) + + b.ResetTimer() + b.SetBytes(16) + for i := 0; i < b.N; i++ { + Bytes16(dst0, s0, s1) + } +} + +func BenchmarkBytesN_16Bytes(b *testing.B) { + s0 := make([]byte, 16) + s1 := make([]byte, 16) + fillRandom(s0) + fillRandom(s1) + dst0 := make([]byte, 16) + + b.ResetTimer() + b.SetBytes(16) + for i := 0; i < b.N; i++ { + BytesA(dst0, s0, s1) + } +} + +func BenchmarkEncode(b *testing.B) { + sizes := []int{4 * kb, mb, 8 * mb} + + srcNums := []int{5, 10} + + var feats []int + switch getCPUFeature() { + case avx512: + feats = append(feats, avx512) + feats = append(feats, avx2) + feats = append(feats, sse2) + case avx2: + feats = append(feats, avx2) + feats = append(feats, sse2) + case sse2: + feats = append(feats, sse2) + default: + feats = append(feats, generic) + } + + b.Run("", benchEncRun(benchEnc, srcNums, sizes, feats)) +} + +func benchEncRun(f func(*testing.B, int, int, int), srcNums, sizes, feats []int) func(*testing.B) { + return func(b *testing.B) { + for _, feat := range feats { + for _, srcNum := range srcNums { + for _, size := range sizes { + b.Run(fmt.Sprintf("(%d+1)-%s-%s", srcNum, byteToStr(size), featToStr(feat)), func(b *testing.B) { + f(b, srcNum, size, feat) + }) + } + } + } + } +} + +func benchEnc(b *testing.B, srcNum, size, feat int) { + dst := make([]byte, size) + src := make([][]byte, srcNum) + for i := 0; i < srcNum; i++ { + src[i] = make([]byte, size) + fillRandom(src[i]) + } + cpuFeature = feat + + b.SetBytes(int64((srcNum + 1) * size)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + encode(dst, src) + } +} + +func fillRandom(p []byte) { + rand.Read(p) +} + +func byteToStr(n int) string { + if n >= mb { + return fmt.Sprintf("%dMB", n/mb) + } + + return fmt.Sprintf("%dKB", n/kb) +} diff --git a/vendor/github.com/templexxx/xorsimd/xoravx2_amd64.s b/vendor/github.com/templexxx/xorsimd/xoravx2_amd64.s new file mode 100644 index 00000000..23cf924d --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/xoravx2_amd64.s @@ -0,0 +1,124 @@ +// Copyright (c) 2019. Temple3x (temple3x@gmail.com) +// +// Use of this source code is governed by the MIT License +// that can be found in the LICENSE file. 
+ +#include "textflag.h" + +#define dst BX // parity's address +#define d2src SI // two-dimension src_slice's address +#define csrc CX // cnt of src +#define len DX // len of vect +#define pos R8 // job position in vect + +#define csrc_tmp R9 +#define d2src_off R10 +#define src_tmp R11 +#define not_aligned_len R12 +#define src_val0 R13 +#define src_val1 R14 + +// func encodeAVX2(dst []byte, src [][]byte) +TEXT ·encodeAVX2(SB), NOSPLIT, $0 + MOVQ d+0(FP), dst + MOVQ s+24(FP), d2src + MOVQ c+32(FP), csrc + MOVQ l+8(FP), len + TESTQ $127, len + JNZ not_aligned + +aligned: + MOVQ $0, pos + +loop128b: + MOVQ csrc, csrc_tmp // store src_cnt -> csrc_tmp + SUBQ $2, csrc_tmp + MOVQ $0, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp // get first src_vect's addr -> src_tmp + VMOVDQU (src_tmp)(pos*1), Y0 + VMOVDQU 32(src_tmp)(pos*1), Y1 + VMOVDQU 64(src_tmp)(pos*1), Y2 + VMOVDQU 96(src_tmp)(pos*1), Y3 + +next_vect: + ADDQ $24, d2src_off // len(slice) = 24 + MOVQ (d2src)(d2src_off*1), src_tmp // next data_vect + VMOVDQU (src_tmp)(pos*1), Y4 + VMOVDQU 32(src_tmp)(pos*1), Y5 + VMOVDQU 64(src_tmp)(pos*1), Y6 + VMOVDQU 96(src_tmp)(pos*1), Y7 + VPXOR Y4, Y0, Y0 + VPXOR Y5, Y1, Y1 + VPXOR Y6, Y2, Y2 + VPXOR Y7, Y3, Y3 + SUBQ $1, csrc_tmp + JGE next_vect + + VMOVDQU Y0, (dst)(pos*1) + VMOVDQU Y1, 32(dst)(pos*1) + VMOVDQU Y2, 64(dst)(pos*1) + VMOVDQU Y3, 96(dst)(pos*1) + + ADDQ $128, pos + CMPQ len, pos + JNE loop128b + VZEROUPPER + RET + +loop_1b: + MOVQ csrc, csrc_tmp + MOVQ $0, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + SUBQ $2, csrc_tmp + MOVB -1(src_tmp)(len*1), src_val0 // encode from the end of src + +next_vect_1b: + ADDQ $24, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + MOVB -1(src_tmp)(len*1), src_val1 + XORB src_val1, src_val0 + SUBQ $1, csrc_tmp + JGE next_vect_1b + + MOVB src_val0, -1(dst)(len*1) + SUBQ $1, len + TESTQ $7, len + JNZ loop_1b + + CMPQ len, $0 + JE ret + TESTQ $127, len + JZ aligned + +not_aligned: + TESTQ $7, len + JNE loop_1b + MOVQ len, not_aligned_len + ANDQ $127, not_aligned_len + +loop_8b: + MOVQ csrc, csrc_tmp + MOVQ $0, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + SUBQ $2, csrc_tmp + MOVQ -8(src_tmp)(len*1), src_val0 + +next_vect_8b: + ADDQ $24, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + MOVQ -8(src_tmp)(len*1), src_val1 + XORQ src_val1, src_val0 + SUBQ $1, csrc_tmp + JGE next_vect_8b + + MOVQ src_val0, -8(dst)(len*1) + SUBQ $8, len + SUBQ $8, not_aligned_len + JG loop_8b + + CMPQ len, $128 + JGE aligned + RET + +ret: + RET diff --git a/vendor/github.com/templexxx/xorsimd/xoravx512_amd64.s b/vendor/github.com/templexxx/xorsimd/xoravx512_amd64.s new file mode 100644 index 00000000..2ba6b756 --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/xoravx512_amd64.s @@ -0,0 +1,124 @@ +// Copyright (c) 2019. Temple3x (temple3x@gmail.com) +// +// Use of this source code is governed by the MIT License +// that can be found in the LICENSE file. 
+ +#include "textflag.h" + +#define dst BX // parity's address +#define d2src SI // two-dimension src_slice's address +#define csrc CX // cnt of src +#define len DX // len of vect +#define pos R8 // job position in vect + +#define csrc_tmp R9 +#define d2src_off R10 +#define src_tmp R11 +#define not_aligned_len R12 +#define src_val0 R13 +#define src_val1 R14 + +// func encodeAVX512(dst []byte, src [][]byte) +TEXT ·encodeAVX512(SB), NOSPLIT, $0 + MOVQ d+0(FP), dst + MOVQ src+24(FP), d2src + MOVQ c+32(FP), csrc + MOVQ l+8(FP), len + TESTQ $255, len + JNZ not_aligned + +aligned: + MOVQ $0, pos + +loop256b: + MOVQ csrc, csrc_tmp // store src_cnt -> csrc_tmp + SUBQ $2, csrc_tmp + MOVQ $0, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp // get first src_vect's addr -> src_tmp + VMOVDQU8 (src_tmp)(pos*1), Z0 + VMOVDQU8 64(src_tmp)(pos*1), Z1 + VMOVDQU8 128(src_tmp)(pos*1), Z2 + VMOVDQU8 192(src_tmp)(pos*1), Z3 + +next_vect: + ADDQ $24, d2src_off // len(slice) = 24 + MOVQ (d2src)(d2src_off*1), src_tmp // next data_vect + VMOVDQU8 (src_tmp)(pos*1), Z4 + VMOVDQU8 64(src_tmp)(pos*1), Z5 + VMOVDQU8 128(src_tmp)(pos*1), Z6 + VMOVDQU8 192(src_tmp)(pos*1), Z7 + VPXORQ Z4, Z0, Z0 + VPXORQ Z5, Z1, Z1 + VPXORQ Z6, Z2, Z2 + VPXORQ Z7, Z3, Z3 + SUBQ $1, csrc_tmp + JGE next_vect + + VMOVDQU8 Z0, (dst)(pos*1) + VMOVDQU8 Z1, 64(dst)(pos*1) + VMOVDQU8 Z2, 128(dst)(pos*1) + VMOVDQU8 Z3, 192(dst)(pos*1) + + ADDQ $256, pos + CMPQ len, pos + JNE loop256b + VZEROUPPER + RET + +loop_1b: + MOVQ csrc, csrc_tmp + MOVQ $0, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + SUBQ $2, csrc_tmp + MOVB -1(src_tmp)(len*1), src_val0 // encode from the end of src + +next_vect_1b: + ADDQ $24, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + MOVB -1(src_tmp)(len*1), src_val1 + XORB src_val1, src_val0 + SUBQ $1, csrc_tmp + JGE next_vect_1b + + MOVB src_val0, -1(dst)(len*1) + SUBQ $1, len + TESTQ $7, len + JNZ loop_1b + + CMPQ len, $0 + JE ret + TESTQ $255, len + JZ aligned + +not_aligned: + TESTQ $7, len + JNE loop_1b + MOVQ len, not_aligned_len + ANDQ $255, not_aligned_len + +loop_8b: + MOVQ csrc, csrc_tmp + MOVQ $0, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + SUBQ $2, csrc_tmp + MOVQ -8(src_tmp)(len*1), src_val0 + +next_vect_8b: + ADDQ $24, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + MOVQ -8(src_tmp)(len*1), src_val1 + XORQ src_val1, src_val0 + SUBQ $1, csrc_tmp + JGE next_vect_8b + + MOVQ src_val0, -8(dst)(len*1) + SUBQ $8, len + SUBQ $8, not_aligned_len + JG loop_8b + + CMPQ len, $256 + JGE aligned + RET + +ret: + RET diff --git a/vendor/github.com/templexxx/xorsimd/xorbytes_amd64.s b/vendor/github.com/templexxx/xorsimd/xorbytes_amd64.s new file mode 100644 index 00000000..8f67edd2 --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/xorbytes_amd64.s @@ -0,0 +1,72 @@ +#include "textflag.h" + +// func bytesN(dst, a, b *byte, n int) +TEXT ·bytesN(SB), NOSPLIT, $0 + MOVQ d+0(FP), BX + MOVQ a+8(FP), SI + MOVQ b+16(FP), CX + MOVQ n+24(FP), DX + TESTQ $15, DX // AND 15 & len, if not zero jump to not_aligned. + JNZ not_aligned + +aligned: + MOVQ $0, AX // position in slices + +loop16b: + MOVOU (SI)(AX*1), X0 // XOR 16byte forwards. + MOVOU (CX)(AX*1), X1 + PXOR X1, X0 + MOVOU X0, (BX)(AX*1) + ADDQ $16, AX + CMPQ DX, AX + JNE loop16b + RET + +loop_1b: + SUBQ $1, DX // XOR 1byte backwards. + MOVB (SI)(DX*1), DI + MOVB (CX)(DX*1), AX + XORB AX, DI + MOVB DI, (BX)(DX*1) + TESTQ $7, DX // AND 7 & len, if not zero jump to loop_1b. + JNZ loop_1b + CMPQ DX, $0 // if len is 0, ret. 
+ JE ret + TESTQ $15, DX // AND 15 & len, if zero jump to aligned. + JZ aligned + +not_aligned: + TESTQ $7, DX // AND $7 & len, if not zero jump to loop_1b. + JNE loop_1b + SUBQ $8, DX // XOR 8bytes backwards. + MOVQ (SI)(DX*1), DI + MOVQ (CX)(DX*1), AX + XORQ AX, DI + MOVQ DI, (BX)(DX*1) + CMPQ DX, $16 // if len is greater or equal 16 here, it must be aligned. + JGE aligned + +ret: + RET + +// func bytes8(dst, a, b *byte) +TEXT ·bytes8(SB), NOSPLIT, $0 + MOVQ d+0(FP), BX + MOVQ a+8(FP), SI + MOVQ b+16(FP), CX + MOVQ (SI), DI + MOVQ (CX), AX + XORQ AX, DI + MOVQ DI, (BX) + RET + +// func bytes16(dst, a, b *byte) +TEXT ·bytes16(SB), NOSPLIT, $0 + MOVQ d+0(FP), BX + MOVQ a+8(FP), SI + MOVQ b+16(FP), CX + MOVOU (SI), X0 + MOVOU (CX), X1 + PXOR X1, X0 + MOVOU X0, (BX) + RET diff --git a/vendor/github.com/templexxx/xorsimd/xorsse2_amd64.s b/vendor/github.com/templexxx/xorsimd/xorsse2_amd64.s new file mode 100644 index 00000000..38df9489 --- /dev/null +++ b/vendor/github.com/templexxx/xorsimd/xorsse2_amd64.s @@ -0,0 +1,123 @@ +// Copyright (c) 2019. Temple3x (temple3x@gmail.com) +// +// Use of this source code is governed by the MIT License +// that can be found in the LICENSE file. + +#include "textflag.h" + +#define dst BX // parity's address +#define d2src SI // two-dimension src_slice's address +#define csrc CX // cnt of src +#define len DX // len of vect +#define pos R8 // job position in vect + +#define csrc_tmp R9 +#define d2src_off R10 +#define src_tmp R11 +#define not_aligned_len R12 +#define src_val0 R13 +#define src_val1 R14 + +// func encodeSSE2(dst []byte, src [][]byte) +TEXT ·encodeSSE2(SB), NOSPLIT, $0 + MOVQ d+0(FP), dst + MOVQ src+24(FP), d2src + MOVQ c+32(FP), csrc + MOVQ l+8(FP), len + TESTQ $63, len + JNZ not_aligned + +aligned: + MOVQ $0, pos + +loop64b: + MOVQ csrc, csrc_tmp + SUBQ $2, csrc_tmp + MOVQ $0, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + MOVOU (src_tmp)(pos*1), X0 + MOVOU 16(src_tmp)(pos*1), X1 + MOVOU 32(src_tmp)(pos*1), X2 + MOVOU 48(src_tmp)(pos*1), X3 + +next_vect: + ADDQ $24, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + MOVOU (src_tmp)(pos*1), X4 + MOVOU 16(src_tmp)(pos*1), X5 + MOVOU 32(src_tmp)(pos*1), X6 + MOVOU 48(src_tmp)(pos*1), X7 + PXOR X4, X0 + PXOR X5, X1 + PXOR X6, X2 + PXOR X7, X3 + SUBQ $1, csrc_tmp + JGE next_vect + + MOVOU X0, (dst)(pos*1) + MOVOU X1, 16(dst)(pos*1) + MOVOU X2, 32(dst)(pos*1) + MOVOU X3, 48(dst)(pos*1) + + ADDQ $64, pos + CMPQ len, pos + JNE loop64b + RET + +loop_1b: + MOVQ csrc, csrc_tmp + MOVQ $0, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + SUBQ $2, csrc_tmp + MOVB -1(src_tmp)(len*1), src_val0 + +next_vect_1b: + ADDQ $24, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + MOVB -1(src_tmp)(len*1), src_val1 + XORB src_val1, src_val0 + SUBQ $1, csrc_tmp + JGE next_vect_1b + + MOVB src_val0, -1(dst)(len*1) + SUBQ $1, len + TESTQ $7, len + JNZ loop_1b + + CMPQ len, $0 + JE ret + TESTQ $63, len + JZ aligned + +not_aligned: + TESTQ $7, len + JNE loop_1b + MOVQ len, not_aligned_len + ANDQ $63, not_aligned_len + +loop_8b: + MOVQ csrc, csrc_tmp + MOVQ $0, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + SUBQ $2, csrc_tmp + MOVQ -8(src_tmp)(len*1), src_val0 + +next_vect_8b: + ADDQ $24, d2src_off + MOVQ (d2src)(d2src_off*1), src_tmp + MOVQ -8(src_tmp)(len*1), src_val1 + XORQ src_val1, src_val0 + SUBQ $1, csrc_tmp + JGE next_vect_8b + + MOVQ src_val0, -8(dst)(len*1) + SUBQ $8, len + SUBQ $8, not_aligned_len + JG loop_8b + + CMPQ len, $64 + JGE aligned + RET + +ret: + RET diff --git a/vendor/github.com/tjfoc/gmsm/.gitignore 
b/vendor/github.com/tjfoc/gmsm/.gitignore
new file mode 100644
index 00000000..48d6c952
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/.gitignore
@@ -0,0 +1,34 @@
+
+#Ignore thumbnails created by Windows
+Thumbs.db
+#Ignore files built by Visual Studio
+*.obj
+*.exe
+*.pdb
+*.user
+*.aps
+*.pch
+*.vspscc
+*_i.c
+*_p.c
+*.ncb
+*.suo
+*.tlb
+*.tlh
+*.bak
+*.cache
+*.ilk
+*.log
+[Bb]in
+[Dd]ebug*/
+*.lib
+*.sbr
+obj/
+[Rr]elease*/
+_ReSharper*/
+[Tt]est[Rr]esult*
+.vs/
+#Nuget packages folder
+packages/
+*.pem
+.idea
\ No newline at end of file
diff --git a/vendor/github.com/tjfoc/gmsm/.travis.yml b/vendor/github.com/tjfoc/gmsm/.travis.yml
new file mode 100644
index 00000000..6621a877
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/.travis.yml
@@ -0,0 +1,35 @@
+sudo: false
+dist: bionic
+language: go
+os:
+  - linux
+  - osx
+osx_image: xcode11
+go:
+  - 1.14.x
+  - 1.15.x
+before_install:
+  - export GO111MODULE=on
+install:
+  - go get -u golang.org/x/lint/golint
+  - export golint=$(go list -f {{.Target}} golang.org/x/lint/golint)
+  - export GODEBUG=x509ignoreCN=0
+  - go mod vendor
+  - go build -v ./sm2
+  - go build -v ./sm3
+  - go build -v ./sm4
+  - go build -v ./x509
+  - go build -v ./pkcs12
+  - go build -v ./gmtls/gmcredentials
+  - go build -v ./gmtls/gmcredentials/echo
+  - go build -v ./gmtls/websvr
+script:
+  - go vet ./sm2
+  - go vet ./sm3
+  - go vet ./sm4
+  - go vet ./x509
+  - go vet ./pkcs12
+  - go vet ./gmtls/gmcredentials
+  - go vet ./gmtls/websvr
+  - golint .
+  - go test -v ./...
diff --git "a/vendor/github.com/tjfoc/gmsm/API\344\275\277\347\224\250\350\257\264\346\230\216.md" "b/vendor/github.com/tjfoc/gmsm/API\344\275\277\347\224\250\350\257\264\346\230\216.md"
new file mode 100644
index 00000000..4a8d414c
--- /dev/null
+++ "b/vendor/github.com/tjfoc/gmsm/API\344\275\277\347\224\250\350\257\264\346\230\216.md"
@@ -0,0 +1,132 @@
+# GM/T Go API Usage Guide
+
+
+
+- [GM/T Go API Usage Guide](#gmt-go-api-usage-guide)
+  - [Installing the Go package](#installing-the-go-package)
+  - [SM2 elliptic curve public-key cryptographic algorithm](#sm2-elliptic-curve-public-key-cryptographic-algorithm)
+    - [Code example](#code-example)
+  - [SM3 cryptographic hash algorithm](#sm3-cryptographic-hash-algorithm)
+    - [Code example](#code-example-1)
+  - [SM4 block cipher algorithm](#sm4-block-cipher-algorithm)
+    - [Code example](#code-example-2)
+  - [Related code examples](#related-code-examples)
+  - [GM SSL (TLCP)](#gm-ssl-tlcp)
+
+
+
+## Installing the Go package
+
+```bash
+go get -u github.com/tjfoc/gmsm
+```
+## SM2 elliptic curve public-key cryptographic algorithm
+
+> Public key cryptographic algorithm SM2 based on elliptic curves
+
+- The SM2 standards followed are: GM/T 0003.1-2012, GM/T 0003.2-2012, GM/T 0003.3-2012, GM/T 0003.4-2012, GM/T 0003.5-2012, GM/T 0009-2012 and GM/T 0010-2012
+- go package: `github.com/tjfoc/gmsm/sm2`

+### Code example
+
+```Go
+	priv, err := sm2.GenerateKey(rand.Reader) // generate a key pair
+	if err != nil {
+		log.Fatal(err)
+	}
+	msg := []byte("Tongji Fintech Research Institute")
+	pub := &priv.PublicKey
+	ciphertxt, err := pub.EncryptAsn1(msg, rand.Reader) // SM2 encryption
+	if err != nil {
+		log.Fatal(err)
+	}
+	fmt.Printf("ciphertext: %x\n", ciphertxt)
+	plaintxt, err := priv.DecryptAsn1(ciphertxt) // SM2 decryption
+	if err != nil {
+		log.Fatal(err)
+	}
+	if !bytes.Equal(msg, plaintxt) {
+		log.Fatal("plaintext does not match the original message")
+	}
+
+	sign, err := priv.Sign(rand.Reader, msg, nil) // SM2 signing
+	if err != nil {
+		log.Fatal(err)
+	}
+	isok := pub.Verify(msg, sign) // SM2 signature verification
+	fmt.Printf("Verified: %v\n", isok)
+```
+## SM3 cryptographic hash algorithm
+
+> SM3 cryptographic hash algorithm
+
+- The SM3 standard followed is: GM/T 0004-2012
+- go package: `github.com/tjfoc/gmsm/sm3`
+- `type SM3 struct` is an implementation of the native hash.Hash interface
+
+### Code example
+
+```Go
+	data := "test"
+	h := sm3.New()
+	h.Write([]byte(data))
+	sum := h.Sum(nil)
+	fmt.Printf("digest value is: %x\n", sum)
+```
+
+## SM4 block cipher algorithm
+
+> SM4 block cipher algorithm
+
+- The SM4 standard followed is: GM/T 0002-2012
+- go package: `github.com/tjfoc/gmsm/sm4`
+
+### Code example
+
+```Go
+	import "fmt"
+	import "log"
+	import "github.com/tjfoc/gmsm/sm4"
+
+	func main() {
+		key := []byte("1234567890abcdef")
+		fmt.Printf("key = %v\n", key)
+		data := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10}
+		fmt.Printf("data = %x\n", data)
+		iv := []byte("0000000000000000")
+		if err := sm4.SetIV(iv); err != nil { // set the IV used by the SM4 implementation; a default is used when unset
+			log.Fatalf("sm4 set iv error: %s", err)
+		}
+		ecbMsg, err := sm4.Sm4Ecb(key, data, true) // SM4 ECB-mode encryption with PKCS#7 padding
+		if err != nil {
+			log.Fatalf("sm4 enc error: %s", err)
+		}
+		fmt.Printf("ecbMsg = %x\n", ecbMsg)
+		ecbDec, err := sm4.Sm4Ecb(key, ecbMsg, false) // SM4 ECB-mode decryption with PKCS#7 padding
+		if err != nil {
+			log.Fatalf("sm4 dec error: %s", err)
+		}
+		fmt.Printf("ecbDec = %x\n", ecbDec)
+	}
+```
+
+
+## Related code examples
+
+- [SM2 algorithm: sm2/sm2_test.go](sm2/sm2_test.go)
+- [SM3 algorithm: sm3/sm3_test.go](sm3/sm3_test.go)
+- [SM4 algorithm: sm4/sm4_test.go](sm4/sm4_test.go)
+- [x509 GM certificates: x509/x509_test.go](x509/x509_test.go)
+
+## GM SSL (TLCP)
+
+- The GM SSL protocol follows the standard GM/T 0024-2014 "SSL VPN technical specification"
+- **The GM SSL protocol has been upgraded to the TLCP protocol, following GB/T 38636-2020 "Information security technology - Transport layer cryptography protocol"**, which adds the SM4 GCM cipher mode.
+
+For details on GM SSL, see the document: [tjfoc GM SSL protocol quick start](gmtls/websvr/README.md)
+
+Example entry points:
+
+- [GM HTTPS web server test case: gmtls/websvr/websvr.go](gmtls/websvr/websvr.go)
+- [GM TLS gRPC test case: gmtls/gmcredentials/credentials_test.go](gmtls/gmcredentials/credentials_test.go)
diff --git a/vendor/github.com/tjfoc/gmsm/CHANGELOG.md b/vendor/github.com/tjfoc/gmsm/CHANGELOG.md
new file mode 100644
index 00000000..bc56f7a3
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/CHANGELOG.md
@@ -0,0 +1,85 @@
+## Changelog
+### 2.0 update (June 9, 2021)
+- [FIX] Changed the prefix of the compressed SM2 public-key format
+- [FIX] Fixed several bugs in the GM TLS code
+- [FIX] Fixed a bug in the SM4 GCM mode
+- [NEW] Import/export of public and private keys in hex format
+- [NEW] The IV used for SM4 encryption can now be set
+- [NEW] Adaptive negotiation between GM TLS and standard TLS
+
+
+### 1.4 update (Sep 15, 2020)
+**Breaking changes**
+- [FIX] SM2 private-key generation, signing and encryption now take a random source, which may be supplied externally.
+- [FIX] Public/private keys and certificates are imported from and exported to PEM-format []byte data rather than PEM files.
+- [FIX] The SM2 certificate code was moved into the dedicated x509 package.
+- [FIX] Code cleanup; removed unused methods
+- [NEW] Implemented SM4 ECB and CBC modes
+- [NEW] Moved the GM TLS implementation into this library's gmtls package
+- [NEW] The GM TLS implementation now uses dual certificates
+
+### 1.2 update (Feb 20, 2019)
+
+- [NEW] Implemented PKCS#7 signing and verification
+- [FIX] SM2 signing and verification now fully follow the GM/T 0003 series of standards and are compatible with the CFCA Java SDK
+
+**Breaking change: certificates are incompatible with earlier versions**
+
+### 1.1.1 update
+- Added the following functions to support caller-supplied user ID information
+  SignDigitToSignData, which converts the big integers r and s produced by signing into the serialized signature format
+  Sm2Sign, which signs with caller-supplied user ID information
+  Sm2Verify, which verifies with caller-supplied user ID information
+
+
+### 1.1.0 update:
+- Improved performance; the gains are listed below
+  Note: this optimization pass is not exhaustive, only a first attempt; further optimization will follow as time allows
+```
+    old:
+    generate key:
+    BenchmarkSM2-4      1000     2517147 ns/op    1156476 B/op    11273 allocs/op
+    sign:
+    BenchmarkSM2-4       300     6297498 ns/op    2321890 B/op    22653 allocs/op
+    verify:
+    BenchmarkSM2-4      2000     8557215 ns/op    3550626 B/op    34627 allocs/op
+    encrypt:
+    BenchmarkSM2-4      2000     8304840 ns/op    3483113 B/op    33967 allocs/op
+    decrypt:
+    BenchmarkSM2-4      2000     5726181 ns/op    2321728 B/op    22644 allocs/op
+    new:
+    generate key:
+    BenchmarkSM2-4      5000      303656 ns/op       2791 B/op       41 allocs/op
+    sign:
+    BenchmarkSM2-4      2000      652465 ns/op       8828 B/op      133 allocs/op
+    verify:
+    BenchmarkSM2-4      1000     2004511 ns/op     122709 B/op     1738 allocs/op
+    encrypt:
+    BenchmarkSM2-4      1000     1984419 ns/op     118560 B/op     1687 allocs/op
+    decrypt:
+    BenchmarkSM2-4      1000     1725001 ns/op     118331 B/op     1679 allocs/op
+```
+
+### 1.0.1 update:
+- Added a global sbox to improve SM4 performance (by https://github.com/QwertyJack)
+
+
+### 1.0 update:
+- Added the following OIDs
+  SM3WithSM2 1.2.156.10197.1.501
+  SHA1WithSM2 1.2.156.10197.1.502
+  SHA256WithSM2 1.2.156.10197.1.503
+
+- Certificates generated by the x509 package can now use SM3 as the hash algorithm
+
+- Introduced the following hash algorithms
+  RIPEMD160
+  SHA3_256
+  SHA3_384
+  SHA3_512
+  SHA3_SM3
+  Users need to install golang.org/x/crypto themselves
+
+
+
+
diff --git a/vendor/github.com/tjfoc/gmsm/LICENSE b/vendor/github.com/tjfoc/gmsm/LICENSE
new file mode 100644
index 00000000..8dada3ed
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/LICENSE
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License.
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/vendor/github.com/tjfoc/gmsm/README.md b/vendor/github.com/tjfoc/gmsm/README.md
new file mode 100644
index 00000000..06319e13
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/README.md
@@ -0,0 +1,58 @@
+
+# gmsm
+GM SM2/3/4 library based on Golang
+=======
+
+[![Build Status](https://travis-ci.com/tjfoc/gmsm.svg?branch=master)](https://travis-ci.com/github/tjfoc/gmsm)
+
+
+## Feature
+    gmsm provides the following main functionality:
+
+    SM2: GM elliptic curve algorithm library
+        . Supports the basic operations Generate Key, Sign and Verify
+        . Supports encrypted and unencrypted PEM file formats (the encryption scheme follows RFC 5958; see the code for the implementation)
+        . Supports certificate generation and certificate reading/writing (the interface is compatible with rsa and ecdsa certificates)
+        . Supports certificate chain operations (the interface is compatible with rsa and ecdsa)
+        . Supports the crypto.Signer interface
+
+    SM3: GM hash algorithm library
+        . Supports the basic sm3Sum operation
+        . Supports the hash.Hash interface
+
+    SM4: GM block cipher algorithm library
+        . Supports the basic operations Generate Key, Encrypt and Decrypt
+        . Provides the cipher.Block interface
+        . Supports encrypted and unencrypted PEM file formats (encryption is applied per PEM block; the function used is x509.EncryptPEMBlock)
+
+## [Usage guide](./API使用说明.md)
+
+## Communication
+tjfoc GM cryptography discussion
+
+[![Join the chat at https://gitter.im/tjfoc/gmsm](https://badges.gitter.im/tjfoc/gmsm.svg)](https://gitter.im/tjfoc/gmsm?utm_source=badge&utm_medium=badge&utm_campaign=-badge&utm_content=badge)
+
+
+- If you are interested in GM cryptography open-source technology and its applications, add the "Suzhou Tongji Blockchain Research Institute assistant" on WeChat and reply "国密算法进群" to join the Tongji blockchain GM algorithm discussion group. The WeChat QR code is:
+  ![WeChat QR code](https://github.com/tjfoc/wutongchian-public/blob/master/wutongchain.png)
+
+- Or send an email to tj@wutongchain.com
+
+
+ ## License
+ Copyright: Suzhou Tongji Fintech Research Institute Co., Ltd. (http://www.wutongchain.com/)
+
+ Copyright 2017- Suzhou Tongji Fintech Research Institute. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+ See the License for the specific language governing permissions and limitations under the License.
+
+
+
+
diff --git a/vendor/github.com/tjfoc/gmsm/azure-pipelines.yml b/vendor/github.com/tjfoc/gmsm/azure-pipelines.yml
new file mode 100644
index 00000000..b7c77279
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/azure-pipelines.yml
@@ -0,0 +1,34 @@
+pool:
+  vmImage: 'ubuntu-18.04'
+
+strategy:
+  matrix:
+    LTS:
+      goVersion: '1.15'
+    latest:
+      goVersion: '1.14'
+
+steps:
+  - task: GoTool@0
+    inputs:
+      version: $(goVersion)
+  - script: export GODEBUG=x509ignoreCN=0
+  - script: go build -v ./sm2
+  - script: go build -v ./sm3
+  - script: go build -v ./sm4
+  - script: go build -v ./x509
+  - script: go build -v ./pkcs12
+  - script: go build -v ./gmtls/gmcredentials
+  - script: go build -v ./gmtls/gmcredentials/echo
+  - script: go build -v ./gmtls/websvr
+  - script: go mod vendor
+  - script: go vet ./sm2
+  - script: go vet ./sm3
+  - script: go vet ./sm4
+  - script: go vet ./x509
+  - script: go vet ./pkcs12
+  - script: go vet ./gmtls/gmcredentials
+  - script: go vet ./gmtls/websvr
+  - script: go test -v ./...
+ displayName: go test recursive + diff --git a/vendor/github.com/tjfoc/gmsm/go.mod b/vendor/github.com/tjfoc/gmsm/go.mod new file mode 100644 index 00000000..380ab0d6 --- /dev/null +++ b/vendor/github.com/tjfoc/gmsm/go.mod @@ -0,0 +1,10 @@ +module github.com/tjfoc/gmsm + +go 1.14 + +require ( + github.com/golang/protobuf v1.4.2 + golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee + golang.org/x/net v0.0.0-20201010224723-4f7140c49acb + google.golang.org/grpc v1.31.0 +) diff --git a/vendor/github.com/tjfoc/gmsm/go.sum b/vendor/github.com/tjfoc/gmsm/go.sum new file mode 100644 index 00000000..2caba79b --- /dev/null +++ b/vendor/github.com/tjfoc/gmsm/go.sum @@ -0,0 +1,80 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee h1:4yd7jl+vXjalO5ztz6Vc1VADv+S/80LGJmyl1ROJ2AI= 
+golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20201010224723-4f7140c49acb h1:mUVeFHoDKis5nxCAzoAi7E8Ghb86EXh/RK6wtvJIqRY= +golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto 
v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55 h1:gSJIx1SDwno+2ElGhA4+qG2zF97qiUzTM+rQ0klBOcE= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI= +google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/vendor/github.com/tjfoc/gmsm/sm4/padding/README.md b/vendor/github.com/tjfoc/gmsm/sm4/padding/README.md new file mode 100644 index 00000000..3b371f17 --- /dev/null +++ b/vendor/github.com/tjfoc/gmsm/sm4/padding/README.md @@ -0,0 +1,82 @@ +# PCSK#7 填充读写流 + +直接使用分块加密,往往需要手动完成填充和去填充的过程。流程较为固定,但是比较繁琐,对于文件和流的处理,不是很友好。 + +PKCS#7填充和去除填充: + +- `padding.PKCS7PaddingReader` 写入结束是添加填充 +- `padding.PKCS7PaddingWriter` 读取时去除填充 + +封装的填充模式加密: + +- `padding.P7BlockEnc` +- `padding.P7BlockDecrypt` + +以上方法简化了,对文件和流类型的加密解密过程。 + +## 带填充的模式加密和解密 + +> 流:实现了 `io.Writer`、`io.Reader`接口的类型。 + +流的SM4分块CBC模式加密: + +```go +func main() { + src := bytes.Repeat([]byte{7}, 16) + srcIn := bytes.NewBuffer(src) + encOut := bytes.NewBuffer(make([]byte, 0, 1024)) + + key := make([]byte, 16) + iv := make([]byte, 16) + _, _ = rand.Read(key) + _, _ = rand.Read(iv) + fmt.Printf("key: %02X\n", key) + fmt.Printf("iv : %02X\n", iv) + block, err := sm4.NewCipher(key) + if err != nil { + panic(err) + } + encrypter := cipher.NewCBCEncrypter(block, iv) + // P7填充的CBC加密 + err = padding.P7BlockEnc(encrypter, srcIn, encOut) + if err != nil { + panic(err) + } + fmt.Printf("原文: %02X\n", src) + fmt.Printf("加密: %02X\n", encOut.Bytes()) +} +``` + +流的SM4分块CBC模式解密: + +```go +func main() { + /** + key: 4C9CA3D17263F6F558D65ADB561465BD + iv : 221908D1C4BD730BEB011319D1368E49 + 原文: 07070707070707070707070707070707 + 加密: 310CA2472DCE15CCC58E1BE69B876002F443556CCFB86B1BA0341B6BFBED4C1A + */ + encOut := bytes.NewBuffer(make([]byte, 0, 1024)) + key,_ := hex.DecodeString("4C9CA3D17263F6F558D65ADB561465BD") + iv,_ := hex.DecodeString("221908D1C4BD730BEB011319D1368E49") + block, err := sm4.NewCipher(key) + if err != nil { + panic(err) + } + ciphertext, _ := hex.DecodeString("310CA2472DCE15CCC58E1BE69B876002F443556CCFB86B1BA0341B6BFBED4C1A") + cipherReader := 
bytes.NewReader(ciphertext)
+	decrypter := cipher.NewCBCDecrypter(block, iv)
+	decOut := bytes.NewBuffer(make([]byte, 0, 1024))
+	err = padding.P7BlockDecrypt(decrypter, cipherReader, decOut)
+	if err != nil {
+		panic(err)
+	}
+
+	fmt.Printf("解密: %02X\n", decOut.Bytes())
+}
+```
+
+## PKCS#7 padding
+
+See the test cases: [pkcs7_padding_io_test.go](./pkcs7_padding_io_test.go)
diff --git a/vendor/github.com/tjfoc/gmsm/sm4/padding/bloc_cryptor.go b/vendor/github.com/tjfoc/gmsm/sm4/padding/bloc_cryptor.go
new file mode 100644
index 00000000..15d4eaf1
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/sm4/padding/bloc_cryptor.go
@@ -0,0 +1,56 @@
+package padding
+
+import (
+	"crypto/cipher"
+	"io"
+)
+
+// P7BlockDecrypt decrypts the ciphertext and strips the PKCS#7 padding.
+// decrypter: the block-mode decrypter
+// in: the ciphertext input stream
+// out: the plaintext output stream
+func P7BlockDecrypt(decrypter cipher.BlockMode, in io.Reader, out io.Writer) error {
+	bufIn := make([]byte, 1024)
+	bufOut := make([]byte, 1024)
+	p7Out := NewPKCS7PaddingWriter(out, decrypter.BlockSize())
+	for {
+		n, err := in.Read(bufIn)
+		if err != nil && err != io.EOF {
+			return err
+		}
+		if n == 0 {
+			break
+		}
+		decrypter.CryptBlocks(bufOut, bufIn[:n])
+		_, err = p7Out.Write(bufOut[:n])
+		if err != nil {
+			return err
+		}
+	}
+	return p7Out.Final()
+}
+
+// P7BlockEnc pads the plaintext with PKCS#7 padding and writes the encrypted result.
+// encrypter: the block-mode encrypter
+// in: the plaintext input stream
+// out: the ciphertext output stream
+func P7BlockEnc(encrypter cipher.BlockMode, in io.Reader, out io.Writer) error {
+	bufIn := make([]byte, 1024)
+	bufOut := make([]byte, 1024)
+	p7In := NewPKCS7PaddingReader(in, encrypter.BlockSize())
+	for {
+		n, err := p7In.Read(bufIn)
+		if err != nil && err != io.EOF {
+			return err
+		}
+		if n == 0 {
+			break
+		}
+		encrypter.CryptBlocks(bufOut, bufIn[:n])
+		_, err = out.Write(bufOut[:n])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/vendor/github.com/tjfoc/gmsm/sm4/padding/bloc_cryptor_test.go b/vendor/github.com/tjfoc/gmsm/sm4/padding/bloc_cryptor_test.go
new file mode 100644
index 00000000..10570c78
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/sm4/padding/bloc_cryptor_test.go
@@ -0,0 +1,48 @@
+package padding
+
+import (
+	"bytes"
+	"crypto/cipher"
+	"crypto/rand"
+	"fmt"
+	"github.com/tjfoc/gmsm/sm4"
+	"testing"
+)
+
+func TestP7BlockDecrypt(t *testing.T) {
+	src := bytes.Repeat([]byte{7}, 16)
+
+	srcIn := bytes.NewBuffer(src)
+	encOut := bytes.NewBuffer(make([]byte, 0, 1024))
+
+	key := make([]byte, 16)
+	iv := make([]byte, 16)
+	_, _ = rand.Read(key)
+	_, _ = rand.Read(iv)
+	fmt.Printf("key: %02X\n", key)
+	fmt.Printf("iv : %02X\n", iv)
+	block, err := sm4.NewCipher(key)
+	if err != nil {
+		t.Fatal(err)
+	}
+	encrypter := cipher.NewCBCEncrypter(block, iv)
+
+	err = P7BlockEnc(encrypter, srcIn, encOut)
+	if err != nil {
+		t.Fatal(err)
+	}
+	fmt.Printf("plaintext:  %02X\n", src)
+	fmt.Printf("ciphertext: %02X\n", encOut.Bytes())
+
+	decrypter := cipher.NewCBCDecrypter(block, iv)
+	decOut := bytes.NewBuffer(make([]byte, 0, 1024))
+	err = P7BlockDecrypt(decrypter, encOut, decOut)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	fmt.Printf("decrypted:  %02X\n", decOut.Bytes())
+	if !bytes.Equal(src, decOut.Bytes()) {
+		t.Fatalf("actual decryption result: %02X, expected: %02X", decOut.Bytes(), src)
+	}
+}
diff --git a/vendor/github.com/tjfoc/gmsm/sm4/padding/pkcs7_padding_io.go b/vendor/github.com/tjfoc/gmsm/sm4/padding/pkcs7_padding_io.go
new file mode 100644
index 00000000..1826ed44
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/sm4/padding/pkcs7_padding_io.go
@@ -0,0 +1,144 @@
+package padding
+
+import (
+	"bytes"
+	"errors"
+	"io"
+)
+
+// PKCS7PaddingReader is an input stream that appends PKCS#7 padding
+type PKCS7PaddingReader struct {
+	fIn       io.Reader
+	padding   io.Reader
+	blockSize int
+	readed    int64
+	eof       bool
+	eop       bool
+}
+
+// NewPKCS7PaddingReader creates a PKCS#7 padding Reader
+// in: the input stream
+// blockSize: the block size
+func NewPKCS7PaddingReader(in io.Reader, blockSize int) *PKCS7PaddingReader {
+	return &PKCS7PaddingReader{
+		fIn:       in,
+		padding:   nil,
+		eof:       false,
+		eop:       false,
+		blockSize: blockSize,
+	}
+}
+
+func (p *PKCS7PaddingReader) Read(buf []byte) (int, error) {
+	/*
+		- Read from the file
+			- If the file has enough bytes left, return them directly
+			- Otherwise:
+				- n bytes were read, m more bytes are needed
+				- read them from the padding and append them to buf
+		- Once both are exhausted, return EOF: the whole Reader has ended
+	*/
+	// Everything has been read
+	if p.eof && p.eop {
+		return 0, io.EOF
+	}
+
+	var n, off = 0, 0
+	var err error
+	if !p.eof {
+		// Read from the file
+		n, err = p.fIn.Read(buf)
+		if err != nil && !errors.Is(err, io.EOF) {
+			// Return on a real error
+			return 0, err
+		}
+		p.readed += int64(n)
+		if errors.Is(err, io.EOF) {
+			// Mark the file as finished
+			p.eof = true
+		}
+		if n == len(buf) {
+			// Enough bytes, return directly
+			return n, nil
+		}
+		// The file cannot fill the buffer; build the padding from the number of bytes read so far
+		p.newPadding()
+		// Ask the padding for the missing bytes
+		off = n
+	}
+
+	if !p.eop {
+		// Read from the padding stream
+		var n2 = 0
+		n2, err = p.padding.Read(buf[off:])
+		n += n2
+		if errors.Is(err, io.EOF) {
+			p.eop = true
+		}
+	}
+	return n, err
+}
+
+// newPadding creates the padding
+func (p *PKCS7PaddingReader) newPadding() {
+	if p.padding != nil {
+		return
+	}
+	size := p.blockSize - int(p.readed%int64(p.blockSize))
+	padding := bytes.Repeat([]byte{byte(size)}, size)
+	p.padding = bytes.NewReader(padding)
+}
+
+// PKCS7PaddingWriter is an output stream that strips PKCS#7 padding; the
+// padding is removed from the final block according to the padding value.
+type PKCS7PaddingWriter struct {
+	cache     *bytes.Buffer // cache area
+	swap      []byte        // temporary swap area
+	out       io.Writer     // output destination
+	blockSize int           // block size
+}
+
+// NewPKCS7PaddingWriter creates a PKCS#7 padding Writer that can strip the padding
+func NewPKCS7PaddingWriter(out io.Writer, blockSize int) *PKCS7PaddingWriter {
+	cache := bytes.NewBuffer(make([]byte, 0, 1024))
+	swap := make([]byte, 1024)
+	return &PKCS7PaddingWriter{out: out, blockSize: blockSize, cache: cache, swap: swap}
+}
+
+// Write keeps back one block's worth of data and writes everything else to the output
+func (p *PKCS7PaddingWriter) Write(buff []byte) (n int, err error) {
+	// Write into the cache
+	n, err = p.cache.Write(buff)
+	if err != nil {
+		return 0, err
+	}
+	if p.cache.Len() > p.blockSize {
+		// Read out everything beyond one block and write it to the real out
+		size := p.cache.Len() - p.blockSize
+		_, _ = p.cache.Read(p.swap[:size])
+		_, err = p.out.Write(p.swap[:size])
+		if err != nil {
+			return 0, err
+		}
+	}
+	return n, err
+
+}
+
+// Final strips the padding and writes the last block
+func (p *PKCS7PaddingWriter) Final() error {
+	// After Write, the cache holds exactly one block of data
+	b := p.cache.Bytes()
+	length := len(b)
+	if length != p.blockSize {
+		return errors.New("invalid PKCS7 padding")
+	}
+	if length == 0 {
+		return nil
+	}
+	unpadding := int(b[length-1])
+	if unpadding > p.blockSize || unpadding == 0 {
+		return errors.New("invalid PKCS7 padding")
+	}
+	_, err := p.out.Write(b[:(length - unpadding)])
+	return err
+}
diff --git a/vendor/github.com/tjfoc/gmsm/sm4/padding/pkcs7_padding_io_test.go b/vendor/github.com/tjfoc/gmsm/sm4/padding/pkcs7_padding_io_test.go
new file mode 100644
index 00000000..8308d9f0
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/sm4/padding/pkcs7_padding_io_test.go
@@ -0,0 +1,72 @@
+package padding
+
+import (
+	"bytes"
+	"io"
+	"testing"
+)
+
+// Tests the PKCS#7 padding Reader
+func TestPaddingFileReader_Read(t *testing.T) {
+	srcIn := bytes.NewBuffer(bytes.Repeat([]byte{'A'}, 16))
+	p := NewPKCS7PaddingReader(srcIn, 16)
+
+	tests := []struct {
+		name    string
+		buf     []byte
+		want    int
+		wantErr error
+	}{
+		{"read 1B from the file", make([]byte, 1), 1, nil},
+		{"straddling read: 15B of file, 1B of padding", make([]byte, 16), 16, nil},
+		{"read 3B of padding", make([]byte, 3), 3, nil},
+		{"read 16B past the end of the padding", make([]byte, 16), 12, nil},
+		{"end of stream, 16B", make([]byte, 16), 0, io.EOF},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := p.Read(tt.buf)
+			if err != tt.wantErr {
+				t.Errorf("Read() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if got != tt.want {
+				t.Errorf("Read() got = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+// Tests the PKCS#7 padding Writer
+func TestPKCS7PaddingWriter_Write(t *testing.T) {
+	src := []byte{
+		0, 1, 2, 3, 4, 5, 6, 7,
+	}
+	paddedSrc := append(src, bytes.Repeat([]byte{0x08}, 8)...)
+	reader := bytes.NewReader(paddedSrc)
+	out := bytes.NewBuffer(make([]byte, 0, 64))
+	writer := NewPKCS7PaddingWriter(out, 8)
+
+	for {
+		buf := make([]byte, 3)
+		n, err := reader.Read(buf)
+		if err != nil && err != io.EOF {
+			t.Fatal(err)
+		}
+		if n == 0 {
+			break
+		}
+		_, err = writer.Write(buf[:n])
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	err := writer.Final()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !bytes.Equal(out.Bytes(), src) {
+		t.Fatalf("after stripping the padding got %02X, expected %02X", out.Bytes(), src)
+	}
+}
diff --git a/vendor/github.com/tjfoc/gmsm/sm4/sm4.go b/vendor/github.com/tjfoc/gmsm/sm4/sm4.go
new file mode 100644
index 00000000..0e301deb
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/sm4/sm4.go
@@ -0,0 +1,492 @@
+/*
+Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+sm4 acceleration
+modified by Jack, 2017 Oct
+*/
+
+package sm4
+
+import (
+	"bytes"
+	"crypto/cipher"
+	"errors"
+	"strconv"
+)
+
+const BlockSize = 16
+
+var IV = make([]byte, BlockSize)
+
+type SM4Key []byte
+
+// Sm4Cipher is an instance of SM4 encryption.
+type Sm4Cipher struct { + subkeys []uint32 + block1 []uint32 + block2 []byte +} + +// sm4密钥参量 +var fk = [4]uint32{ + 0xa3b1bac6, 0x56aa3350, 0x677d9197, 0xb27022dc, +} + +// sm4密钥参量 +var ck = [32]uint32{ + 0x00070e15, 0x1c232a31, 0x383f464d, 0x545b6269, + 0x70777e85, 0x8c939aa1, 0xa8afb6bd, 0xc4cbd2d9, + 0xe0e7eef5, 0xfc030a11, 0x181f262d, 0x343b4249, + 0x50575e65, 0x6c737a81, 0x888f969d, 0xa4abb2b9, + 0xc0c7ced5, 0xdce3eaf1, 0xf8ff060d, 0x141b2229, + 0x30373e45, 0x4c535a61, 0x686f767d, 0x848b9299, + 0xa0a7aeb5, 0xbcc3cad1, 0xd8dfe6ed, 0xf4fb0209, + 0x10171e25, 0x2c333a41, 0x484f565d, 0x646b7279, +} + +// sm4密钥参量 +var sbox = [256]uint8{ + 0xd6, 0x90, 0xe9, 0xfe, 0xcc, 0xe1, 0x3d, 0xb7, 0x16, 0xb6, 0x14, 0xc2, 0x28, 0xfb, 0x2c, 0x05, + 0x2b, 0x67, 0x9a, 0x76, 0x2a, 0xbe, 0x04, 0xc3, 0xaa, 0x44, 0x13, 0x26, 0x49, 0x86, 0x06, 0x99, + 0x9c, 0x42, 0x50, 0xf4, 0x91, 0xef, 0x98, 0x7a, 0x33, 0x54, 0x0b, 0x43, 0xed, 0xcf, 0xac, 0x62, + 0xe4, 0xb3, 0x1c, 0xa9, 0xc9, 0x08, 0xe8, 0x95, 0x80, 0xdf, 0x94, 0xfa, 0x75, 0x8f, 0x3f, 0xa6, + 0x47, 0x07, 0xa7, 0xfc, 0xf3, 0x73, 0x17, 0xba, 0x83, 0x59, 0x3c, 0x19, 0xe6, 0x85, 0x4f, 0xa8, + 0x68, 0x6b, 0x81, 0xb2, 0x71, 0x64, 0xda, 0x8b, 0xf8, 0xeb, 0x0f, 0x4b, 0x70, 0x56, 0x9d, 0x35, + 0x1e, 0x24, 0x0e, 0x5e, 0x63, 0x58, 0xd1, 0xa2, 0x25, 0x22, 0x7c, 0x3b, 0x01, 0x21, 0x78, 0x87, + 0xd4, 0x00, 0x46, 0x57, 0x9f, 0xd3, 0x27, 0x52, 0x4c, 0x36, 0x02, 0xe7, 0xa0, 0xc4, 0xc8, 0x9e, + 0xea, 0xbf, 0x8a, 0xd2, 0x40, 0xc7, 0x38, 0xb5, 0xa3, 0xf7, 0xf2, 0xce, 0xf9, 0x61, 0x15, 0xa1, + 0xe0, 0xae, 0x5d, 0xa4, 0x9b, 0x34, 0x1a, 0x55, 0xad, 0x93, 0x32, 0x30, 0xf5, 0x8c, 0xb1, 0xe3, + 0x1d, 0xf6, 0xe2, 0x2e, 0x82, 0x66, 0xca, 0x60, 0xc0, 0x29, 0x23, 0xab, 0x0d, 0x53, 0x4e, 0x6f, + 0xd5, 0xdb, 0x37, 0x45, 0xde, 0xfd, 0x8e, 0x2f, 0x03, 0xff, 0x6a, 0x72, 0x6d, 0x6c, 0x5b, 0x51, + 0x8d, 0x1b, 0xaf, 0x92, 0xbb, 0xdd, 0xbc, 0x7f, 0x11, 0xd9, 0x5c, 0x41, 0x1f, 0x10, 0x5a, 0xd8, + 0x0a, 0xc1, 0x31, 0x88, 0xa5, 0xcd, 0x7b, 0xbd, 0x2d, 0x74, 0xd0, 0x12, 0xb8, 0xe5, 0xb4, 0xb0, + 0x89, 0x69, 0x97, 0x4a, 0x0c, 0x96, 0x77, 0x7e, 0x65, 0xb9, 0xf1, 0x09, 0xc5, 0x6e, 0xc6, 0x84, + 0x18, 0xf0, 0x7d, 0xec, 0x3a, 0xdc, 0x4d, 0x20, 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48, +} + +var sbox0 = [256]uint32{ + 0xd55b5b8e, 0x924242d0, 0xeaa7a74d, 0xfdfbfb06, 0xcf3333fc, 0xe2878765, 0x3df4f4c9, 0xb5dede6b, 0x1658584e, 0xb4dada6e, 0x14505044, 0xc10b0bca, 0x28a0a088, 0xf8efef17, 0x2cb0b09c, 0x05141411, + 0x2bacac87, 0x669d9dfb, 0x986a6af2, 0x77d9d9ae, 0x2aa8a882, 0xbcfafa46, 0x04101014, 0xc00f0fcf, 0xa8aaaa02, 0x45111154, 0x134c4c5f, 0x269898be, 0x4825256d, 0x841a1a9e, 0x0618181e, 0x9b6666fd, + 0x9e7272ec, 0x4309094a, 0x51414110, 0xf7d3d324, 0x934646d5, 0xecbfbf53, 0x9a6262f8, 0x7be9e992, 0x33ccccff, 0x55515104, 0x0b2c2c27, 0x420d0d4f, 0xeeb7b759, 0xcc3f3ff3, 0xaeb2b21c, 0x638989ea, + 0xe7939374, 0xb1cece7f, 0x1c70706c, 0xaba6a60d, 0xca2727ed, 0x08202028, 0xeba3a348, 0x975656c1, 0x82020280, 0xdc7f7fa3, 0x965252c4, 0xf9ebeb12, 0x74d5d5a1, 0x8d3e3eb3, 0x3ffcfcc3, 0xa49a9a3e, + 0x461d1d5b, 0x071c1c1b, 0xa59e9e3b, 0xfff3f30c, 0xf0cfcf3f, 0x72cdcdbf, 0x175c5c4b, 0xb8eaea52, 0x810e0e8f, 0x5865653d, 0x3cf0f0cc, 0x1964647d, 0xe59b9b7e, 0x87161691, 0x4e3d3d73, 0xaaa2a208, + 0x69a1a1c8, 0x6aadadc7, 0x83060685, 0xb0caca7a, 0x70c5c5b5, 0x659191f4, 0xd96b6bb2, 0x892e2ea7, 0xfbe3e318, 0xe8afaf47, 0x0f3c3c33, 0x4a2d2d67, 0x71c1c1b0, 0x5759590e, 0x9f7676e9, 0x35d4d4e1, + 0x1e787866, 0x249090b4, 0x0e383836, 0x5f797926, 0x628d8def, 0x59616138, 0xd2474795, 0xa08a8a2a, 0x259494b1, 0x228888aa, 0x7df1f18c, 0x3bececd7, 
0x01040405, 0x218484a5, 0x79e1e198, 0x851e1e9b, + 0xd7535384, 0x00000000, 0x4719195e, 0x565d5d0b, 0x9d7e7ee3, 0xd04f4f9f, 0x279c9cbb, 0x5349491a, 0x4d31317c, 0x36d8d8ee, 0x0208080a, 0xe49f9f7b, 0xa2828220, 0xc71313d4, 0xcb2323e8, 0x9c7a7ae6, + 0xe9abab42, 0xbdfefe43, 0x882a2aa2, 0xd14b4b9a, 0x41010140, 0xc41f1fdb, 0x38e0e0d8, 0xb7d6d661, 0xa18e8e2f, 0xf4dfdf2b, 0xf1cbcb3a, 0xcd3b3bf6, 0xfae7e71d, 0x608585e5, 0x15545441, 0xa3868625, + 0xe3838360, 0xacbaba16, 0x5c757529, 0xa6929234, 0x996e6ef7, 0x34d0d0e4, 0x1a686872, 0x54555501, 0xafb6b619, 0x914e4edf, 0x32c8c8fa, 0x30c0c0f0, 0xf6d7d721, 0x8e3232bc, 0xb3c6c675, 0xe08f8f6f, + 0x1d747469, 0xf5dbdb2e, 0xe18b8b6a, 0x2eb8b896, 0x800a0a8a, 0x679999fe, 0xc92b2be2, 0x618181e0, 0xc30303c0, 0x29a4a48d, 0x238c8caf, 0xa9aeae07, 0x0d343439, 0x524d4d1f, 0x4f393976, 0x6ebdbdd3, + 0xd6575781, 0xd86f6fb7, 0x37dcdceb, 0x44151551, 0xdd7b7ba6, 0xfef7f709, 0x8c3a3ab6, 0x2fbcbc93, 0x030c0c0f, 0xfcffff03, 0x6ba9a9c2, 0x73c9c9ba, 0x6cb5b5d9, 0x6db1b1dc, 0x5a6d6d37, 0x50454515, + 0x8f3636b9, 0x1b6c6c77, 0xadbebe13, 0x904a4ada, 0xb9eeee57, 0xde7777a9, 0xbef2f24c, 0x7efdfd83, 0x11444455, 0xda6767bd, 0x5d71712c, 0x40050545, 0x1f7c7c63, 0x10404050, 0x5b696932, 0xdb6363b8, + 0x0a282822, 0xc20707c5, 0x31c4c4f5, 0x8a2222a8, 0xa7969631, 0xce3737f9, 0x7aeded97, 0xbff6f649, 0x2db4b499, 0x75d1d1a4, 0xd3434390, 0x1248485a, 0xbae2e258, 0xe6979771, 0xb6d2d264, 0xb2c2c270, + 0x8b2626ad, 0x68a5a5cd, 0x955e5ecb, 0x4b292962, 0x0c30303c, 0x945a5ace, 0x76ddddab, 0x7ff9f986, 0x649595f1, 0xbbe6e65d, 0xf2c7c735, 0x0924242d, 0xc61717d1, 0x6fb9b9d6, 0xc51b1bde, 0x86121294, + 0x18606078, 0xf3c3c330, 0x7cf5f589, 0xefb3b35c, 0x3ae8e8d2, 0xdf7373ac, 0x4c353579, 0x208080a0, 0x78e5e59d, 0xedbbbb56, 0x5e7d7d23, 0x3ef8f8c6, 0xd45f5f8b, 0xc82f2fe7, 0x39e4e4dd, 0x49212168, +} + +var sbox1 = [256]uint32{ + 0x5b5b8ed5, 0x4242d092, 0xa7a74dea, 0xfbfb06fd, 0x3333fccf, 0x878765e2, 0xf4f4c93d, 0xdede6bb5, 0x58584e16, 0xdada6eb4, 0x50504414, 0x0b0bcac1, 0xa0a08828, 0xefef17f8, 0xb0b09c2c, 0x14141105, + 0xacac872b, 0x9d9dfb66, 0x6a6af298, 0xd9d9ae77, 0xa8a8822a, 0xfafa46bc, 0x10101404, 0x0f0fcfc0, 0xaaaa02a8, 0x11115445, 0x4c4c5f13, 0x9898be26, 0x25256d48, 0x1a1a9e84, 0x18181e06, 0x6666fd9b, + 0x7272ec9e, 0x09094a43, 0x41411051, 0xd3d324f7, 0x4646d593, 0xbfbf53ec, 0x6262f89a, 0xe9e9927b, 0xccccff33, 0x51510455, 0x2c2c270b, 0x0d0d4f42, 0xb7b759ee, 0x3f3ff3cc, 0xb2b21cae, 0x8989ea63, + 0x939374e7, 0xcece7fb1, 0x70706c1c, 0xa6a60dab, 0x2727edca, 0x20202808, 0xa3a348eb, 0x5656c197, 0x02028082, 0x7f7fa3dc, 0x5252c496, 0xebeb12f9, 0xd5d5a174, 0x3e3eb38d, 0xfcfcc33f, 0x9a9a3ea4, + 0x1d1d5b46, 0x1c1c1b07, 0x9e9e3ba5, 0xf3f30cff, 0xcfcf3ff0, 0xcdcdbf72, 0x5c5c4b17, 0xeaea52b8, 0x0e0e8f81, 0x65653d58, 0xf0f0cc3c, 0x64647d19, 0x9b9b7ee5, 0x16169187, 0x3d3d734e, 0xa2a208aa, + 0xa1a1c869, 0xadadc76a, 0x06068583, 0xcaca7ab0, 0xc5c5b570, 0x9191f465, 0x6b6bb2d9, 0x2e2ea789, 0xe3e318fb, 0xafaf47e8, 0x3c3c330f, 0x2d2d674a, 0xc1c1b071, 0x59590e57, 0x7676e99f, 0xd4d4e135, + 0x7878661e, 0x9090b424, 0x3838360e, 0x7979265f, 0x8d8def62, 0x61613859, 0x474795d2, 0x8a8a2aa0, 0x9494b125, 0x8888aa22, 0xf1f18c7d, 0xececd73b, 0x04040501, 0x8484a521, 0xe1e19879, 0x1e1e9b85, + 0x535384d7, 0x00000000, 0x19195e47, 0x5d5d0b56, 0x7e7ee39d, 0x4f4f9fd0, 0x9c9cbb27, 0x49491a53, 0x31317c4d, 0xd8d8ee36, 0x08080a02, 0x9f9f7be4, 0x828220a2, 0x1313d4c7, 0x2323e8cb, 0x7a7ae69c, + 0xabab42e9, 0xfefe43bd, 0x2a2aa288, 0x4b4b9ad1, 0x01014041, 0x1f1fdbc4, 0xe0e0d838, 0xd6d661b7, 0x8e8e2fa1, 0xdfdf2bf4, 0xcbcb3af1, 0x3b3bf6cd, 0xe7e71dfa, 0x8585e560, 
0x54544115, 0x868625a3, + 0x838360e3, 0xbaba16ac, 0x7575295c, 0x929234a6, 0x6e6ef799, 0xd0d0e434, 0x6868721a, 0x55550154, 0xb6b619af, 0x4e4edf91, 0xc8c8fa32, 0xc0c0f030, 0xd7d721f6, 0x3232bc8e, 0xc6c675b3, 0x8f8f6fe0, + 0x7474691d, 0xdbdb2ef5, 0x8b8b6ae1, 0xb8b8962e, 0x0a0a8a80, 0x9999fe67, 0x2b2be2c9, 0x8181e061, 0x0303c0c3, 0xa4a48d29, 0x8c8caf23, 0xaeae07a9, 0x3434390d, 0x4d4d1f52, 0x3939764f, 0xbdbdd36e, + 0x575781d6, 0x6f6fb7d8, 0xdcdceb37, 0x15155144, 0x7b7ba6dd, 0xf7f709fe, 0x3a3ab68c, 0xbcbc932f, 0x0c0c0f03, 0xffff03fc, 0xa9a9c26b, 0xc9c9ba73, 0xb5b5d96c, 0xb1b1dc6d, 0x6d6d375a, 0x45451550, + 0x3636b98f, 0x6c6c771b, 0xbebe13ad, 0x4a4ada90, 0xeeee57b9, 0x7777a9de, 0xf2f24cbe, 0xfdfd837e, 0x44445511, 0x6767bdda, 0x71712c5d, 0x05054540, 0x7c7c631f, 0x40405010, 0x6969325b, 0x6363b8db, + 0x2828220a, 0x0707c5c2, 0xc4c4f531, 0x2222a88a, 0x969631a7, 0x3737f9ce, 0xeded977a, 0xf6f649bf, 0xb4b4992d, 0xd1d1a475, 0x434390d3, 0x48485a12, 0xe2e258ba, 0x979771e6, 0xd2d264b6, 0xc2c270b2, + 0x2626ad8b, 0xa5a5cd68, 0x5e5ecb95, 0x2929624b, 0x30303c0c, 0x5a5ace94, 0xddddab76, 0xf9f9867f, 0x9595f164, 0xe6e65dbb, 0xc7c735f2, 0x24242d09, 0x1717d1c6, 0xb9b9d66f, 0x1b1bdec5, 0x12129486, + 0x60607818, 0xc3c330f3, 0xf5f5897c, 0xb3b35cef, 0xe8e8d23a, 0x7373acdf, 0x3535794c, 0x8080a020, 0xe5e59d78, 0xbbbb56ed, 0x7d7d235e, 0xf8f8c63e, 0x5f5f8bd4, 0x2f2fe7c8, 0xe4e4dd39, 0x21216849, +} + +var sbox2 = [256]uint32{ + 0x5b8ed55b, 0x42d09242, 0xa74deaa7, 0xfb06fdfb, 0x33fccf33, 0x8765e287, 0xf4c93df4, 0xde6bb5de, 0x584e1658, 0xda6eb4da, 0x50441450, 0x0bcac10b, 0xa08828a0, 0xef17f8ef, 0xb09c2cb0, 0x14110514, + 0xac872bac, 0x9dfb669d, 0x6af2986a, 0xd9ae77d9, 0xa8822aa8, 0xfa46bcfa, 0x10140410, 0x0fcfc00f, 0xaa02a8aa, 0x11544511, 0x4c5f134c, 0x98be2698, 0x256d4825, 0x1a9e841a, 0x181e0618, 0x66fd9b66, + 0x72ec9e72, 0x094a4309, 0x41105141, 0xd324f7d3, 0x46d59346, 0xbf53ecbf, 0x62f89a62, 0xe9927be9, 0xccff33cc, 0x51045551, 0x2c270b2c, 0x0d4f420d, 0xb759eeb7, 0x3ff3cc3f, 0xb21caeb2, 0x89ea6389, + 0x9374e793, 0xce7fb1ce, 0x706c1c70, 0xa60daba6, 0x27edca27, 0x20280820, 0xa348eba3, 0x56c19756, 0x02808202, 0x7fa3dc7f, 0x52c49652, 0xeb12f9eb, 0xd5a174d5, 0x3eb38d3e, 0xfcc33ffc, 0x9a3ea49a, + 0x1d5b461d, 0x1c1b071c, 0x9e3ba59e, 0xf30cfff3, 0xcf3ff0cf, 0xcdbf72cd, 0x5c4b175c, 0xea52b8ea, 0x0e8f810e, 0x653d5865, 0xf0cc3cf0, 0x647d1964, 0x9b7ee59b, 0x16918716, 0x3d734e3d, 0xa208aaa2, + 0xa1c869a1, 0xadc76aad, 0x06858306, 0xca7ab0ca, 0xc5b570c5, 0x91f46591, 0x6bb2d96b, 0x2ea7892e, 0xe318fbe3, 0xaf47e8af, 0x3c330f3c, 0x2d674a2d, 0xc1b071c1, 0x590e5759, 0x76e99f76, 0xd4e135d4, + 0x78661e78, 0x90b42490, 0x38360e38, 0x79265f79, 0x8def628d, 0x61385961, 0x4795d247, 0x8a2aa08a, 0x94b12594, 0x88aa2288, 0xf18c7df1, 0xecd73bec, 0x04050104, 0x84a52184, 0xe19879e1, 0x1e9b851e, + 0x5384d753, 0x00000000, 0x195e4719, 0x5d0b565d, 0x7ee39d7e, 0x4f9fd04f, 0x9cbb279c, 0x491a5349, 0x317c4d31, 0xd8ee36d8, 0x080a0208, 0x9f7be49f, 0x8220a282, 0x13d4c713, 0x23e8cb23, 0x7ae69c7a, + 0xab42e9ab, 0xfe43bdfe, 0x2aa2882a, 0x4b9ad14b, 0x01404101, 0x1fdbc41f, 0xe0d838e0, 0xd661b7d6, 0x8e2fa18e, 0xdf2bf4df, 0xcb3af1cb, 0x3bf6cd3b, 0xe71dfae7, 0x85e56085, 0x54411554, 0x8625a386, + 0x8360e383, 0xba16acba, 0x75295c75, 0x9234a692, 0x6ef7996e, 0xd0e434d0, 0x68721a68, 0x55015455, 0xb619afb6, 0x4edf914e, 0xc8fa32c8, 0xc0f030c0, 0xd721f6d7, 0x32bc8e32, 0xc675b3c6, 0x8f6fe08f, + 0x74691d74, 0xdb2ef5db, 0x8b6ae18b, 0xb8962eb8, 0x0a8a800a, 0x99fe6799, 0x2be2c92b, 0x81e06181, 0x03c0c303, 0xa48d29a4, 0x8caf238c, 0xae07a9ae, 0x34390d34, 0x4d1f524d, 0x39764f39, 0xbdd36ebd, + 
0x5781d657, 0x6fb7d86f, 0xdceb37dc, 0x15514415, 0x7ba6dd7b, 0xf709fef7, 0x3ab68c3a, 0xbc932fbc, 0x0c0f030c, 0xff03fcff, 0xa9c26ba9, 0xc9ba73c9, 0xb5d96cb5, 0xb1dc6db1, 0x6d375a6d, 0x45155045, + 0x36b98f36, 0x6c771b6c, 0xbe13adbe, 0x4ada904a, 0xee57b9ee, 0x77a9de77, 0xf24cbef2, 0xfd837efd, 0x44551144, 0x67bdda67, 0x712c5d71, 0x05454005, 0x7c631f7c, 0x40501040, 0x69325b69, 0x63b8db63, + 0x28220a28, 0x07c5c207, 0xc4f531c4, 0x22a88a22, 0x9631a796, 0x37f9ce37, 0xed977aed, 0xf649bff6, 0xb4992db4, 0xd1a475d1, 0x4390d343, 0x485a1248, 0xe258bae2, 0x9771e697, 0xd264b6d2, 0xc270b2c2, + 0x26ad8b26, 0xa5cd68a5, 0x5ecb955e, 0x29624b29, 0x303c0c30, 0x5ace945a, 0xddab76dd, 0xf9867ff9, 0x95f16495, 0xe65dbbe6, 0xc735f2c7, 0x242d0924, 0x17d1c617, 0xb9d66fb9, 0x1bdec51b, 0x12948612, + 0x60781860, 0xc330f3c3, 0xf5897cf5, 0xb35cefb3, 0xe8d23ae8, 0x73acdf73, 0x35794c35, 0x80a02080, 0xe59d78e5, 0xbb56edbb, 0x7d235e7d, 0xf8c63ef8, 0x5f8bd45f, 0x2fe7c82f, 0xe4dd39e4, 0x21684921, +} + +var sbox3 = [256]uint32{ + 0x8ed55b5b, 0xd0924242, 0x4deaa7a7, 0x06fdfbfb, 0xfccf3333, 0x65e28787, 0xc93df4f4, 0x6bb5dede, 0x4e165858, 0x6eb4dada, 0x44145050, 0xcac10b0b, 0x8828a0a0, 0x17f8efef, 0x9c2cb0b0, 0x11051414, + 0x872bacac, 0xfb669d9d, 0xf2986a6a, 0xae77d9d9, 0x822aa8a8, 0x46bcfafa, 0x14041010, 0xcfc00f0f, 0x02a8aaaa, 0x54451111, 0x5f134c4c, 0xbe269898, 0x6d482525, 0x9e841a1a, 0x1e061818, 0xfd9b6666, + 0xec9e7272, 0x4a430909, 0x10514141, 0x24f7d3d3, 0xd5934646, 0x53ecbfbf, 0xf89a6262, 0x927be9e9, 0xff33cccc, 0x04555151, 0x270b2c2c, 0x4f420d0d, 0x59eeb7b7, 0xf3cc3f3f, 0x1caeb2b2, 0xea638989, + 0x74e79393, 0x7fb1cece, 0x6c1c7070, 0x0daba6a6, 0xedca2727, 0x28082020, 0x48eba3a3, 0xc1975656, 0x80820202, 0xa3dc7f7f, 0xc4965252, 0x12f9ebeb, 0xa174d5d5, 0xb38d3e3e, 0xc33ffcfc, 0x3ea49a9a, + 0x5b461d1d, 0x1b071c1c, 0x3ba59e9e, 0x0cfff3f3, 0x3ff0cfcf, 0xbf72cdcd, 0x4b175c5c, 0x52b8eaea, 0x8f810e0e, 0x3d586565, 0xcc3cf0f0, 0x7d196464, 0x7ee59b9b, 0x91871616, 0x734e3d3d, 0x08aaa2a2, + 0xc869a1a1, 0xc76aadad, 0x85830606, 0x7ab0caca, 0xb570c5c5, 0xf4659191, 0xb2d96b6b, 0xa7892e2e, 0x18fbe3e3, 0x47e8afaf, 0x330f3c3c, 0x674a2d2d, 0xb071c1c1, 0x0e575959, 0xe99f7676, 0xe135d4d4, + 0x661e7878, 0xb4249090, 0x360e3838, 0x265f7979, 0xef628d8d, 0x38596161, 0x95d24747, 0x2aa08a8a, 0xb1259494, 0xaa228888, 0x8c7df1f1, 0xd73becec, 0x05010404, 0xa5218484, 0x9879e1e1, 0x9b851e1e, + 0x84d75353, 0x00000000, 0x5e471919, 0x0b565d5d, 0xe39d7e7e, 0x9fd04f4f, 0xbb279c9c, 0x1a534949, 0x7c4d3131, 0xee36d8d8, 0x0a020808, 0x7be49f9f, 0x20a28282, 0xd4c71313, 0xe8cb2323, 0xe69c7a7a, + 0x42e9abab, 0x43bdfefe, 0xa2882a2a, 0x9ad14b4b, 0x40410101, 0xdbc41f1f, 0xd838e0e0, 0x61b7d6d6, 0x2fa18e8e, 0x2bf4dfdf, 0x3af1cbcb, 0xf6cd3b3b, 0x1dfae7e7, 0xe5608585, 0x41155454, 0x25a38686, + 0x60e38383, 0x16acbaba, 0x295c7575, 0x34a69292, 0xf7996e6e, 0xe434d0d0, 0x721a6868, 0x01545555, 0x19afb6b6, 0xdf914e4e, 0xfa32c8c8, 0xf030c0c0, 0x21f6d7d7, 0xbc8e3232, 0x75b3c6c6, 0x6fe08f8f, + 0x691d7474, 0x2ef5dbdb, 0x6ae18b8b, 0x962eb8b8, 0x8a800a0a, 0xfe679999, 0xe2c92b2b, 0xe0618181, 0xc0c30303, 0x8d29a4a4, 0xaf238c8c, 0x07a9aeae, 0x390d3434, 0x1f524d4d, 0x764f3939, 0xd36ebdbd, + 0x81d65757, 0xb7d86f6f, 0xeb37dcdc, 0x51441515, 0xa6dd7b7b, 0x09fef7f7, 0xb68c3a3a, 0x932fbcbc, 0x0f030c0c, 0x03fcffff, 0xc26ba9a9, 0xba73c9c9, 0xd96cb5b5, 0xdc6db1b1, 0x375a6d6d, 0x15504545, + 0xb98f3636, 0x771b6c6c, 0x13adbebe, 0xda904a4a, 0x57b9eeee, 0xa9de7777, 0x4cbef2f2, 0x837efdfd, 0x55114444, 0xbdda6767, 0x2c5d7171, 0x45400505, 0x631f7c7c, 0x50104040, 0x325b6969, 0xb8db6363, + 0x220a2828, 0xc5c20707, 
0xf531c4c4, 0xa88a2222, 0x31a79696, 0xf9ce3737, 0x977aeded, 0x49bff6f6, 0x992db4b4, 0xa475d1d1, 0x90d34343, 0x5a124848, 0x58bae2e2, 0x71e69797, 0x64b6d2d2, 0x70b2c2c2,
+	0xad8b2626, 0xcd68a5a5, 0xcb955e5e, 0x624b2929, 0x3c0c3030, 0xce945a5a, 0xab76dddd, 0x867ff9f9, 0xf1649595, 0x5dbbe6e6, 0x35f2c7c7, 0x2d092424, 0xd1c61717, 0xd66fb9b9, 0xdec51b1b, 0x94861212,
+	0x78186060, 0x30f3c3c3, 0x897cf5f5, 0x5cefb3b3, 0xd23ae8e8, 0xacdf7373, 0x794c3535, 0xa0208080, 0x9d78e5e5, 0x56edbbbb, 0x235e7d7d, 0xc63ef8f8, 0x8bd45f5f, 0xe7c82f2f, 0xdd39e4e4, 0x68492121,
+}
+
+func rl(x uint32, i uint8) uint32 { return (x << (i % 32)) | (x >> (32 - (i % 32))) }
+
+func l0(b uint32) uint32 { return b ^ rl(b, 13) ^ rl(b, 23) }
+
+func feistel0(x0, x1, x2, x3, rk uint32) uint32 { return x0 ^ l0(p(x1^x2^x3^rk)) }
+
+// p is the nonlinear transform τ(.): the S-box applied to each byte of the word.
+func p(a uint32) uint32 {
+	return (uint32(sbox[a>>24]) << 24) ^ (uint32(sbox[(a>>16)&0xff]) << 16) ^ (uint32(sbox[(a>>8)&0xff]) << 8) ^ uint32(sbox[(a)&0xff])
+}
+
+func permuteInitialBlock(b []uint32, block []byte) {
+	for i := 0; i < 4; i++ {
+		b[i] = (uint32(block[i*4]) << 24) | (uint32(block[i*4+1]) << 16) |
+			(uint32(block[i*4+2]) << 8) | (uint32(block[i*4+3]))
+	}
+}
+
+func permuteFinalBlock(b []byte, block []uint32) {
+	for i := 0; i < 4; i++ {
+		b[i*4] = uint8(block[i] >> 24)
+		b[i*4+1] = uint8(block[i] >> 16)
+		b[i*4+2] = uint8(block[i] >> 8)
+		b[i*4+3] = uint8(block[i])
+	}
+}
+
+// cryptBlock is the modified core encryption/decryption routine: 32 rounds,
+// four per loop iteration, using the precomputed word S-boxes sbox0..sbox3.
+func cryptBlock(subkeys []uint32, b []uint32, r []byte, dst, src []byte, decrypt bool) {
+	permuteInitialBlock(b, src)
+
+	// bounds check elimination in major encryption loop
+	// https://go101.org/article/bounds-check-elimination.html
+	_ = b[3]
+	if decrypt {
+		for i := 0; i < 8; i++ {
+			s := subkeys[31-4*i-3 : 31-4*i-3+4]
+			x := b[1] ^ b[2] ^ b[3] ^ s[3]
+			b[0] = b[0] ^ sbox0[x&0xff] ^ sbox1[(x>>8)&0xff] ^ sbox2[(x>>16)&0xff] ^ sbox3[(x>>24)&0xff]
+			x = b[0] ^ b[2] ^ b[3] ^ s[2]
+			b[1] = b[1] ^ sbox0[x&0xff] ^ sbox1[(x>>8)&0xff] ^ sbox2[(x>>16)&0xff] ^ sbox3[(x>>24)&0xff]
+			x = b[0] ^ b[1] ^ b[3] ^ s[1]
+			b[2] = b[2] ^ sbox0[x&0xff] ^ sbox1[(x>>8)&0xff] ^ sbox2[(x>>16)&0xff] ^ sbox3[(x>>24)&0xff]
+			x = b[1] ^ b[2] ^ b[0] ^ s[0]
+			b[3] = b[3] ^ sbox0[x&0xff] ^ sbox1[(x>>8)&0xff] ^ sbox2[(x>>16)&0xff] ^ sbox3[(x>>24)&0xff]
+		}
+	} else {
+		for i := 0; i < 8; i++ {
+			s := subkeys[4*i : 4*i+4]
+			x := b[1] ^ b[2] ^ b[3] ^ s[0]
+			b[0] = b[0] ^ sbox0[x&0xff] ^ sbox1[(x>>8)&0xff] ^ sbox2[(x>>16)&0xff] ^ sbox3[(x>>24)&0xff]
+			x = b[0] ^ b[2] ^ b[3] ^ s[1]
+			b[1] = b[1] ^ sbox0[x&0xff] ^ sbox1[(x>>8)&0xff] ^ sbox2[(x>>16)&0xff] ^ sbox3[(x>>24)&0xff]
+			x = b[0] ^ b[1] ^ b[3] ^ s[2]
+			b[2] = b[2] ^ sbox0[x&0xff] ^ sbox1[(x>>8)&0xff] ^ sbox2[(x>>16)&0xff] ^ sbox3[(x>>24)&0xff]
+			x = b[1] ^ b[2] ^ b[0] ^ s[3]
+			b[3] = b[3] ^ sbox0[x&0xff] ^ sbox1[(x>>8)&0xff] ^ sbox2[(x>>16)&0xff] ^ sbox3[(x>>24)&0xff]
+		}
+	}
+	b[0], b[1], b[2], b[3] = b[3], b[2], b[1], b[0]
+	permuteFinalBlock(r, b)
+	copy(dst, r)
+}
+
+func generateSubKeys(key []byte) []uint32 {
+	subkeys := make([]uint32, 32)
+	b := make([]uint32, 4)
+	permuteInitialBlock(b, key)
+	b[0] ^= fk[0]
+	b[1] ^= fk[1]
+	b[2] ^= fk[2]
+	b[3] ^= fk[3]
+	for i := 0; i < 32; i++ {
+		subkeys[i] = feistel0(b[0], b[1], b[2], b[3], ck[i])
+		b[0], b[1], b[2], b[3] = b[1], b[2], b[3], subkeys[i]
+	}
+	return subkeys
+}
+
+// NewCipher creates and returns a new cipher.Block.
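+// A minimal usage sketch (the key and input below are illustrative values;
+// src and dst must each be exactly one BlockSize-byte block):
+//
+//	key := []byte("1234567890abcdef") // 16 bytes
+//	c, err := NewCipher(key)
+//	if err != nil {
+//		panic(err)
+//	}
+//	dst := make([]byte, BlockSize)
+//	c.Encrypt(dst, src)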
+func NewCipher(key []byte) (cipher.Block, error) {
+	if len(key) != BlockSize {
+		return nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key)))
+	}
+	c := new(Sm4Cipher)
+	c.subkeys = generateSubKeys(key)
+	c.block1 = make([]uint32, 4)
+	c.block2 = make([]byte, 16)
+	return c, nil
+}
+
+func (c *Sm4Cipher) BlockSize() int {
+	return BlockSize
+}
+
+// Encrypt and Decrypt reuse the cipher's scratch buffers (block1, block2),
+// so a Sm4Cipher must not be used from multiple goroutines concurrently.
+func (c *Sm4Cipher) Encrypt(dst, src []byte) {
+	cryptBlock(c.subkeys, c.block1, c.block2, dst, src, false)
+}
+
+func (c *Sm4Cipher) Decrypt(dst, src []byte) {
+	cryptBlock(c.subkeys, c.block1, c.block2, dst, src, true)
+}
+
+func xor(in, iv []byte) (out []byte) {
+	if len(in) != len(iv) {
+		return nil
+	}
+
+	out = make([]byte, len(in))
+	for i := 0; i < len(in); i++ {
+		out[i] = in[i] ^ iv[i]
+	}
+	return
+}
+
+func pkcs7Padding(src []byte) []byte {
+	padding := BlockSize - len(src)%BlockSize
+	padtext := bytes.Repeat([]byte{byte(padding)}, padding)
+	return append(src, padtext...)
+}
+
+func pkcs7UnPadding(src []byte) ([]byte, error) {
+	length := len(src)
+	if length == 0 {
+		return nil, errors.New("Invalid pkcs7 padding (len(padtext) == 0)")
+	}
+	unpadding := int(src[length-1])
+	if unpadding > BlockSize || unpadding == 0 {
+		return nil, errors.New("Invalid pkcs7 padding (unpadding > BlockSize || unpadding == 0)")
+	}
+
+	pad := src[len(src)-unpadding:]
+	for i := 0; i < unpadding; i++ {
+		if pad[i] != byte(unpadding) {
+			return nil, errors.New("Invalid pkcs7 padding (pad[i] != unpadding)")
+		}
+	}
+
+	return src[:(length - unpadding)], nil
+}
+
+func SetIV(iv []byte) error {
+	if len(iv) != BlockSize {
+		return errors.New("SM4: invalid iv size")
+	}
+	IV = iv
+	return nil
+}
+
+func Sm4Cbc(key []byte, in []byte, mode bool) (out []byte, err error) {
+	if len(key) != BlockSize {
+		return nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key)))
+	}
+	var inData []byte
+	if mode {
+		inData = pkcs7Padding(in)
+	} else {
+		inData = in
+	}
+	iv := make([]byte, BlockSize)
+	copy(iv, IV)
+	out = make([]byte, len(inData))
+	c, err := NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+	if mode {
+		for i := 0; i < len(inData)/16; i++ {
+			in_tmp := xor(inData[i*16:i*16+16], iv)
+			out_tmp := make([]byte, 16)
+			c.Encrypt(out_tmp, in_tmp)
+			copy(out[i*16:i*16+16], out_tmp)
+			iv = out_tmp
+		}
+	} else {
+		for i := 0; i < len(inData)/16; i++ {
+			in_tmp := inData[i*16 : i*16+16]
+			out_tmp := make([]byte, 16)
+			c.Decrypt(out_tmp, in_tmp)
+			out_tmp = xor(out_tmp, iv)
+			copy(out[i*16:i*16+16], out_tmp)
+			iv = in_tmp
+		}
+		// surface invalid-padding errors instead of silently returning nil
+		out, err = pkcs7UnPadding(out)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return out, nil
+}
+
+func Sm4Ecb(key []byte, in []byte, mode bool) (out []byte, err error) {
+	if len(key) != BlockSize {
+		return nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key)))
+	}
+	var inData []byte
+	if mode {
+		inData = pkcs7Padding(in)
+	} else {
+		inData = in
+	}
+	out = make([]byte, len(inData))
+	c, err := NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+	if mode {
+		for i := 0; i < len(inData)/16; i++ {
+			in_tmp := inData[i*16 : i*16+16]
+			out_tmp := make([]byte, 16)
+			c.Encrypt(out_tmp, in_tmp)
+			copy(out[i*16:i*16+16], out_tmp)
+		}
+	} else {
+		for i := 0; i < len(inData)/16; i++ {
+			in_tmp := inData[i*16 : i*16+16]
+			out_tmp := make([]byte, 16)
+			c.Decrypt(out_tmp, in_tmp)
+			copy(out[i*16:i*16+16], out_tmp)
+		}
+		// surface invalid-padding errors instead of silently returning nil
+		out, err = pkcs7UnPadding(out)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return out, nil
+}
+
+// Cipher feedback (CFB) mode
+// https://blog.csdn.net/zy_strive_2012/article/details/102520356
+// https://blog.csdn.net/sinat_23338865/article/details/72869841
+func Sm4CFB(key []byte, in []byte, mode bool) (out []byte, err error) {
+	if len(key) != BlockSize {
+		return nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key)))
+	}
+	var inData []byte
+	if mode {
+		inData = pkcs7Padding(in)
+	} else {
+		inData = in
+	}
+
+	out = make([]byte, len(inData))
+	c, err := NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+
+	K := make([]byte, BlockSize)
+	cipherBlock := make([]byte, BlockSize)
+	plainBlock := make([]byte, BlockSize)
+	if mode { // encrypt
+		for i := 0; i < len(inData)/16; i++ {
+			if i == 0 {
+				c.Encrypt(K, IV)
+				cipherBlock = xor(K[:BlockSize], inData[i*16:i*16+16])
+				copy(out[i*16:i*16+16], cipherBlock)
+				continue
+			}
+			c.Encrypt(K, cipherBlock)
+			cipherBlock = xor(K[:BlockSize], inData[i*16:i*16+16])
+			copy(out[i*16:i*16+16], cipherBlock)
+		}
+
+	} else { // decrypt
+		for i := 0; i < len(inData)/16; i++ {
+			if i == 0 {
+				c.Encrypt(K, IV) // CFB decryption also runs the cipher in the encrypt direction, not Decrypt
+				plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) // recover the plaintext block
+				copy(out[i*16:i*16+16], plainBlock)
+				continue
+			}
+			c.Encrypt(K, inData[(i-1)*16:(i-1)*16+16])
+			plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) // recover the plaintext block
+			copy(out[i*16:i*16+16], plainBlock)
+		}
+
+		// surface invalid-padding errors instead of silently returning nil
+		out, err = pkcs7UnPadding(out)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return out, nil
+}
+
+// Output feedback (OFB) mode
+// https://blog.csdn.net/chengqiuming/article/details/82390910
+// https://blog.csdn.net/sinat_23338865/article/details/72869841
+func Sm4OFB(key []byte, in []byte, mode bool) (out []byte, err error) {
+	if len(key) != BlockSize {
+		return nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key)))
+	}
+	var inData []byte
+	if mode {
+		inData = pkcs7Padding(in)
+	} else {
+		inData = in
+	}
+
+	out = make([]byte, len(inData))
+	c, err := NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+
+	K := make([]byte, BlockSize)
+	cipherBlock := make([]byte, BlockSize)
+	plainBlock := make([]byte, BlockSize)
+	shiftIV := make([]byte, BlockSize)
+	if mode { // encrypt
+		for i := 0; i < len(inData)/16; i++ {
+			if i == 0 {
+				c.Encrypt(K, IV)
+				cipherBlock = xor(K[:BlockSize], inData[i*16:i*16+16])
+				copy(out[i*16:i*16+16], cipherBlock)
+				copy(shiftIV, K[:BlockSize])
+				continue
+			}
+			c.Encrypt(K, shiftIV)
+			cipherBlock = xor(K[:BlockSize], inData[i*16:i*16+16])
+			copy(out[i*16:i*16+16], cipherBlock)
+			copy(shiftIV, K[:BlockSize])
+		}
+
+	} else { // decrypt
+		for i := 0; i < len(inData)/16; i++ {
+			if i == 0 {
+				c.Encrypt(K, IV) // OFB decryption also runs the cipher in the encrypt direction, not Decrypt
+				plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) // recover the plaintext block
+				copy(out[i*16:i*16+16], plainBlock)
+				copy(shiftIV, K[:BlockSize])
+				continue
+			}
+			c.Encrypt(K, shiftIV)
+			plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) // recover the plaintext block
+			copy(out[i*16:i*16+16], plainBlock)
+			copy(shiftIV, K[:BlockSize])
+		}
+		// surface invalid-padding errors instead of silently returning nil
+		out, err = pkcs7UnPadding(out)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return out, nil
+}
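+
+// A rough usage sketch for the block modes above (the key and data are
+// illustrative; Sm4CFB and Sm4OFB read the package-level IV, set via SetIV):
+//
+//	key := []byte("1234567890abcdef")
+//	if err := SetIV(make([]byte, BlockSize)); err != nil {
+//		panic(err)
+//	}
+//	ct, _ := Sm4CFB(key, plaintext, true) // encrypt (applies PKCS#7 padding)
+//	pt, _ := Sm4CFB(key, ct, false)       // decrypt (strips the padding)
+//	// Sm4OFB, Sm4Cbc and Sm4Ecb follow the same calling convention.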
diff --git a/vendor/github.com/tjfoc/gmsm/sm4/sm4_gcm.go b/vendor/github.com/tjfoc/gmsm/sm4/sm4_gcm.go
new file mode 100644
index 00000000..a5d531d3
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/sm4/sm4_gcm.go
@@ -0,0 +1,355 @@
+/*
+Copyright Hyperledger-TWGC All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+written by Zhiwei Yan, Oct 2020
+*/
+
+package sm4
+
+import (
+	"errors"
+	"strconv"
+)
+
+// Sm4GCM runs SM4 in GCM authenticated encryption/decryption mode.
+// Paper: The Galois/Counter Mode of Operation (GCM). David A. McGrew, John Viega. 2004.
+// key: symmetric key
+// IV: initialization vector
+// in: plaintext (mode true) or ciphertext (mode false)
+// A: additional authenticated data (AAD)
+// mode: true - encrypt; false - decrypt and recompute the tag
+//
+// return: ciphertext C (or plaintext P), authentication tag T, error
+func Sm4GCM(key []byte, IV, in, A []byte, mode bool) ([]byte, []byte, error) {
+	if len(key) != BlockSize {
+		return nil, nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key)))
+	}
+	if mode {
+		C, T := GCMEncrypt(key, IV, in, A)
+		return C, T, nil
+	} else {
+		P, _T := GCMDecrypt(key, IV, in, A)
+		return P, _T, nil
+	}
+}
+
+// GetH derives the GHASH hash subkey by encrypting the all-zero block.
+// key: symmetric key
+// return: the GHASH hash subkey H
+func GetH(key []byte) (H []byte) {
+	c, err := NewCipher(key)
+	if err != nil {
+		panic(err)
+	}
+
+	zeros := make([]byte, BlockSize)
+	H = make([]byte, BlockSize)
+	c.Encrypt(H, zeros)
+	return H
+}
+
+// addition returns a XOR b: addition in GF(2^128) is bytewise XOR.
+func addition(a, b []byte) (out []byte) {
+	Len := len(a)
+	if Len != len(b) {
+		return nil
+	}
+	out = make([]byte, Len)
+	for i := 0; i < Len; i++ {
+		out[i] = a[i] ^ b[i]
+	}
+	return out
+}
+
+func Rightshift(V []byte) {
+	n := len(V)
+	for i := n - 1; i >= 0; i-- {
+		V[i] = V[i] >> 1
+		if i != 0 {
+			V[i] = ((V[i-1] & 0x01) << 7) | V[i]
+		}
+	}
+}
+
+func findYi(Y []byte, index int) int {
+	var temp byte
+	i := uint(index)
+	temp = Y[i/8]
+	temp = temp >> (7 - i%8)
+	if temp&0x01 == 1 {
+		return 1
+	} else {
+		return 0
+	}
+}
+
+// multiplication computes Z = X * Y in GF(2^128), bitwise, per the GCM spec.
+func multiplication(X, Y []byte) (Z []byte) {
+
+	R := make([]byte, BlockSize)
+	R[0] = 0xe1
+	Z = make([]byte, BlockSize)
+	V := make([]byte, BlockSize)
+	copy(V, X)
+	for i := 0; i <= 127; i++ {
+		if findYi(Y, i) == 1 {
+			Z = addition(Z, V)
+		}
+		if V[BlockSize-1]&0x01 == 0 {
+			Rightshift(V)
+		} else {
+			Rightshift(V)
+			V = addition(V, R)
+		}
+	}
+	return Z
+}
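+
+// GHASH below is the standard GCM hash with subkey H: X_0 = 0 and
+// X_i = (X_{i-1} XOR B_i) * H over the 128-bit blocks B_i of the
+// zero-padded A, then the zero-padded C, then the encoded lengths of A
+// and C; the final X_i is returned. A minimal sketch of how it is used
+// (illustrative names):
+//
+//	H := GetH(key)
+//	X := GHASH(H, aad, ciphertext) // 16 bytes, feeds the tag computation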
+func GHASH(H []byte, A []byte, C []byte) (X []byte) {
+
+	calculm_v := func(m, v int) (int, int) {
+		if m == 0 && v != 0 {
+			m = 1
+			v = v * 8
+		} else if m != 0 && v == 0 {
+			v = BlockSize * 8
+		} else if m != 0 && v != 0 {
+			m = m + 1
+			v = v * 8
+		} else { // m == 0 && v == 0
+			m = 1
+			v = 0
+		}
+		return m, v
+	}
+	m := len(A) / BlockSize
+	v := len(A) % BlockSize
+	m, v = calculm_v(m, v)
+
+	n := len(C) / BlockSize
+	u := (len(C) % BlockSize)
+	n, u = calculm_v(n, u)
+
+	//i=0
+	X = make([]byte, BlockSize*(m+n+2)) //X0 = 0
+	for i := 0; i < BlockSize; i++ {
+		X[i] = 0x00
+	}
+
+	//i=1...m-1
+	for i := 1; i <= m-1; i++ {
+		copy(X[i*BlockSize:i*BlockSize+BlockSize], multiplication(addition(X[(i-1)*BlockSize:(i-1)*BlockSize+BlockSize], A[(i-1)*BlockSize:(i-1)*BlockSize+BlockSize]), H)) // A blocks 1..m-1 are array indices 0..m-2
+	}
+
+	//i=m
+	zeros := make([]byte, (128-v)/8)
+	Am := make([]byte, v/8)
+	copy(Am[:], A[(m-1)*BlockSize:])
+	Am = append(Am, zeros...)
+	copy(X[m*BlockSize:m*BlockSize+BlockSize], multiplication(addition(X[(m-1)*BlockSize:(m-1)*BlockSize+BlockSize], Am), H))
+
+	//i=m+1...m+n-1
+	for i := m + 1; i <= (m + n - 1); i++ {
+		copy(X[i*BlockSize:i*BlockSize+BlockSize], multiplication(addition(X[(i-1)*BlockSize:(i-1)*BlockSize+BlockSize], C[(i-m-1)*BlockSize:(i-m-1)*BlockSize+BlockSize]), H))
+	}
+
+	//i=m+n
+	zeros = make([]byte, (128-u)/8)
+	Cn := make([]byte, u/8)
+	copy(Cn[:], C[(n-1)*BlockSize:])
+	Cn = append(Cn, zeros...)
+	copy(X[(m+n)*BlockSize:(m+n)*BlockSize+BlockSize], multiplication(addition(X[(m+n-1)*BlockSize:(m+n-1)*BlockSize+BlockSize], Cn), H))
+
+	//i=m+n+1
+	var lenAB []byte
+	calculateLenToBytes := func(len int) []byte {
+		data := make([]byte, 8)
+		data[0] = byte((len >> 56) & 0xff)
+		data[1] = byte((len >> 48) & 0xff)
+		data[2] = byte((len >> 40) & 0xff)
+		data[3] = byte((len >> 32) & 0xff)
+		data[4] = byte((len >> 24) & 0xff)
+		data[5] = byte((len >> 16) & 0xff)
+		data[6] = byte((len >> 8) & 0xff)
+		data[7] = byte((len >> 0) & 0xff)
+		return data
+	}
+	lenAB = append(lenAB, calculateLenToBytes(len(A))...)
+	lenAB = append(lenAB, calculateLenToBytes(len(C))...)
+	copy(X[(m+n+1)*BlockSize:(m+n+1)*BlockSize+BlockSize], multiplication(addition(X[(m+n)*BlockSize:(m+n)*BlockSize+BlockSize], lenAB), H))
+	return X[(m+n+1)*BlockSize : (m+n+1)*BlockSize+BlockSize]
+}
+
+// GetY0 generates the initial counter block J0.
+//
+// H: GHASH subkey
+// IV: initialization vector
+// return: the initial counter block (J0)
+func GetY0(H, IV []byte) []byte {
+	if len(IV)*8 == 96 {
+		zero31one1 := []byte{0x00, 0x00, 0x00, 0x01}
+		IV = append(IV, zero31one1...)
+		return IV
+	} else {
+		return GHASH(H, []byte{}, IV)
+	}
+}
+
+func incr(n int, Y_i []byte) (Y_ii []byte) {
+
+	Y_ii = make([]byte, BlockSize*n)
+	copy(Y_ii, Y_i)
+
+	addYone := func(yi, yii []byte) {
+		copy(yii[:], yi[:])
+
+		Len := len(yi)
+		var rc byte = 0x00
+		for i := Len - 1; i >= 0; i-- {
+			if i == Len-1 {
+				if yii[i] < 0xff {
+					yii[i] = yii[i] + 0x01
+					rc = 0x00
+				} else {
+					yii[i] = 0x00
+					rc = 0x01
+				}
+			} else {
+				if yii[i]+rc < 0xff {
+					yii[i] = yii[i] + rc
+					rc = 0x00
+				} else {
+					yii[i] = 0x00
+					rc = 0x01
+				}
+			}
+		}
+	}
+	for i := 1; i < n; i++ { //2^32
+		addYone(Y_ii[(i-1)*BlockSize:(i-1)*BlockSize+BlockSize], Y_ii[i*BlockSize:i*BlockSize+BlockSize])
+	}
+	return Y_ii
+}
+
+func MSB(len int, S []byte) (out []byte) {
+	return S[:len/8]
+}
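+
+// End-to-end sketch of the GCM flow implemented below, via the Sm4GCM
+// wrapper (illustrative values; see sm4_gcm_test.go):
+//
+//	C, T, _ := Sm4GCM(key, iv, plaintext, aad, true) // encrypt and tag
+//	P, T2, _ := Sm4GCM(key, iv, C, aad, false)       // decrypt, recompute tag
+//	ok := bytes.Equal(T, T2)                         // caller verifies the tag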
+
+// GCMEncrypt is the authenticated encryption function (GCM-AE_K).
+// K: symmetric key
+// IV: initialization vector
+// P: plaintext
+// A: additional authenticated data
+//
+// return: ciphertext, authentication tag
+func GCMEncrypt(K, IV, P, A []byte) (C, T []byte) {
+	calculm_v := func(m, v int) (int, int) {
+		if m == 0 && v != 0 {
+			m = 1
+			v = v * 8
+		} else if m != 0 && v == 0 {
+			v = BlockSize * 8
+		} else if m != 0 && v != 0 {
+			m = m + 1
+			v = v * 8
+		} else { // m == 0 && v == 0
+			m = 1
+			v = 0
+		}
+		return m, v
+	}
+	n := len(P) / BlockSize
+	u := len(P) % BlockSize
+	n, u = calculm_v(n, u)
+
+	// a) derive the GHASH subkey by encrypting the all-zero block
+	H := GetH(K)
+
+	Y0 := GetY0(H, IV)
+
+	Y := incr(n+1, Y0)
+	c, err := NewCipher(K)
+	if err != nil {
+		panic(err)
+	}
+	Enc := make([]byte, BlockSize)
+	C = make([]byte, len(P))
+
+	//i=1...n-1
+	for i := 1; i <= n-1; i++ {
+		c.Encrypt(Enc, Y[i*BlockSize:i*BlockSize+BlockSize])
+
+		copy(C[(i-1)*BlockSize:(i-1)*BlockSize+BlockSize], addition(P[(i-1)*BlockSize:(i-1)*BlockSize+BlockSize], Enc))
+	}
+
+	//i=n
+	c.Encrypt(Enc, Y[n*BlockSize:n*BlockSize+BlockSize])
+	out := MSB(u, Enc)
+	copy(C[(n-1)*BlockSize:], addition(P[(n-1)*BlockSize:], out))
+
+	c.Encrypt(Enc, Y0)
+
+	t := 128
+	T = MSB(t, addition(Enc, GHASH(H, A, C)))
+	return C, T
+}
+
+func GCMDecrypt(K, IV, C, A []byte) (P, _T []byte) {
+	calculm_v := func(m, v int) (int, int) {
+		if m == 0 && v != 0 {
+			m = 1
+			v = v * 8
+		} else if m != 0 && v == 0 {
+			v = BlockSize * 8
+		} else if m != 0 && v != 0 {
+			m = m + 1
+			v = v * 8
+		} else { // m == 0 && v == 0
+			m = 1
+			v = 0
+		}
+		return m, v
+	}
+
+	H := GetH(K)
+
+	Y0 := GetY0(H, IV)
+
+	Enc := make([]byte, BlockSize)
+	c, err := NewCipher(K)
+	if err != nil {
+		panic(err)
+	}
+	c.Encrypt(Enc, Y0)
+	t := 128
+	_T = MSB(t, addition(Enc, GHASH(H, A, C)))
+
+	n := len(C) / BlockSize
+	u := len(C) % BlockSize
+	n, u = calculm_v(n, u)
+	Y := incr(n+1, Y0)
+
+	P = make([]byte, len(C))
+	//i=1...n-1; the final, possibly partial, block is handled below with
+	//MSB, mirroring GCMEncrypt, so that a ciphertext whose length is not a
+	//multiple of the block size does not slice past the end of C
+	for i := 1; i <= n-1; i++ {
+		c.Encrypt(Enc, Y[i*BlockSize:i*BlockSize+BlockSize])
+		copy(P[(i-1)*BlockSize:(i-1)*BlockSize+BlockSize], addition(C[(i-1)*BlockSize:(i-1)*BlockSize+BlockSize], Enc))
+	}
+
+	//i=n
+	c.Encrypt(Enc, Y[n*BlockSize:n*BlockSize+BlockSize])
+	out := MSB(u, Enc)
+	copy(P[(n-1)*BlockSize:], addition(C[(n-1)*BlockSize:], out))
+
+	return P, _T
+}
diff --git a/vendor/github.com/tjfoc/gmsm/sm4/sm4_gcm_test.go b/vendor/github.com/tjfoc/gmsm/sm4/sm4_gcm_test.go
new file mode 100644
index 00000000..a5e0900c
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/sm4/sm4_gcm_test.go
@@ -0,0 +1,61 @@
+/*
+Copyright Hyperledger-TWGC All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+written by Zhiwei Yan, Oct 2020
+*/
+
+package sm4
+
+import (
+	"bytes"
+	"fmt"
+	"testing"
+)
+
+func TestSM4GCM(t *testing.T) {
+	key := []byte("1234567890abcdef")
+	data := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10}
+	IV := make([]byte, BlockSize)
+	testA := [][]byte{ // the length of A can be arbitrary
+		[]byte{},
+		[]byte{0x01, 0x23, 0x45, 0x67, 0x89},
+		[]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10},
+	}
+	for _, A := range testA {
+		gcmMsg, T, err := Sm4GCM(key, IV, data, A, true)
+		if err != nil {
+			t.Errorf("sm4 enc error:%s", err)
+		}
+		fmt.Printf("gcmMsg = %x\n", gcmMsg)
+		gcmDec, T_, err := Sm4GCM(key, IV, gcmMsg, A, false)
+		if err != nil {
+			t.Errorf("sm4 dec error:%s", err)
+		}
+		fmt.Printf("gcmDec = %x\n", gcmDec)
+		if bytes.Compare(T, T_) == 0 {
+			fmt.Println("authentication succeeded")
+		}
+		// Negative test: with a different A the tags must not match.
+		A = []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd}
+		gcmDec, T_, err = Sm4GCM(key, IV, gcmMsg, A, false)
+		if err != nil {
+			t.Errorf("sm4 dec error:%s", err)
+		}
+		if bytes.Compare(T, T_) != 0 {
+			fmt.Println("authentication failed (as expected)")
+		}
+	}
+
+}
\ No newline at end of file
diff --git a/vendor/github.com/tjfoc/gmsm/sm4/sm4_test.go b/vendor/github.com/tjfoc/gmsm/sm4/sm4_test.go
new file mode 100644
index 00000000..6115ba14
--- /dev/null
+++ b/vendor/github.com/tjfoc/gmsm/sm4/sm4_test.go
@@ -0,0 +1,153 @@
+/*
+Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sm4 + +import ( + "fmt" + "reflect" + "testing" +) + +func TestSM4(t *testing.T) { + key := []byte("1234567890abcdef") + + fmt.Printf("key = %v\n", key) + data := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10} + err := WriteKeyToPemFile("key.pem", key, nil) + if err != nil { + t.Fatalf("WriteKeyToPem error") + } + key, err = ReadKeyFromPemFile("key.pem", nil) + fmt.Printf("key = %v\n", key) + if err != nil { + t.Fatal(err) + } + fmt.Printf("data = %x\n", data) + ecbMsg, err := Sm4Ecb(key, data, true) + if err != nil { + t.Errorf("sm4 enc error:%s", err) + return + } + fmt.Printf("ecbMsg = %x\n", ecbMsg) + iv := []byte("0000000000000000") + err = SetIV(iv) + fmt.Printf("err = %v\n", err) + ecbDec, err := Sm4Ecb(key, ecbMsg, false) + if err != nil { + t.Errorf("sm4 dec error:%s", err) + return + } + fmt.Printf("ecbDec = %x\n", ecbDec) + if !testCompare(data, ecbDec) { + t.Errorf("sm4 self enc and dec failed") + } + cbcMsg, err := Sm4Cbc(key, data, true) + if err != nil { + t.Errorf("sm4 enc error:%s", err) + } + fmt.Printf("cbcMsg = %x\n", cbcMsg) + cbcDec, err := Sm4Cbc(key, cbcMsg, false) + if err != nil { + t.Errorf("sm4 dec error:%s", err) + return + } + fmt.Printf("cbcDec = %x\n", cbcDec) + if !testCompare(data, cbcDec) { + t.Errorf("sm4 self enc and dec failed") + } + + cbcMsg, err = Sm4CFB(key, data, true) + if err != nil { + t.Errorf("sm4 enc error:%s", err) + } + fmt.Printf("cbcCFB = %x\n", cbcMsg) + + cbcCfb, err := Sm4CFB(key, cbcMsg, false) + if err != nil { + t.Errorf("sm4 dec error:%s", err) + return + } + fmt.Printf("cbcCFB = %x\n", cbcCfb) + + cbcMsg, err = Sm4OFB(key, data, true) + if err != nil { + t.Errorf("sm4 enc error:%s", err) + } + fmt.Printf("cbcOFB = %x\n", cbcMsg) + + cbcOfc, err := Sm4OFB(key, cbcMsg, false) + if err != nil { + t.Errorf("sm4 dec error:%s", err) + return + } + fmt.Printf("cbcOFB = %x\n", cbcOfc) +} + +func BenchmarkSM4(t *testing.B) { + t.ReportAllocs() + key := []byte("1234567890abcdef") + data := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10} + err := WriteKeyToPemFile("key.pem", key, nil) + if err != nil { + t.Fatalf("WriteKeyToPem error") + } + key, err = ReadKeyFromPemFile("key.pem", nil) + if err != nil { + t.Fatal(err) + } + c, err := NewCipher(key) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < t.N; i++ { + d0 := make([]byte, 16) + c.Encrypt(d0, data) + d1 := make([]byte, 16) + c.Decrypt(d1, d0) + } +} + +func TestErrKeyLen(t *testing.T) { + fmt.Printf("\n--------------test key len------------------") + key := []byte("1234567890abcdefg") + _, err := NewCipher(key) + if err != nil { + fmt.Println("\nError key len !") + } + key = []byte("1234") + _, err = NewCipher(key) + if err != nil { + fmt.Println("Error key len !") + } + fmt.Println("------------------end----------------------") +} + +func testCompare(key1, key2 []byte) bool { + if len(key1) != len(key2) { + return false + } + for 
i, v := range key1 { + if i == 1 { + fmt.Println("type of v", reflect.TypeOf(v)) + } + a := key2[i] + if a != v { + return false + } + } + return true +} diff --git a/vendor/github.com/tjfoc/gmsm/sm4/utils.go b/vendor/github.com/tjfoc/gmsm/sm4/utils.go new file mode 100644 index 00000000..93060f40 --- /dev/null +++ b/vendor/github.com/tjfoc/gmsm/sm4/utils.go @@ -0,0 +1,86 @@ +package sm4 + +import ( + "crypto/rand" + "crypto/x509" + "encoding/pem" + "errors" + "io/ioutil" +) + +// ReadKeyFromPem will return SM4Key from PEM format data. +func ReadKeyFromPem(data []byte, pwd []byte) (SM4Key, error) { + block, _ := pem.Decode(data) + if block == nil { + return nil, errors.New("SM4: pem decode failed") + } + if x509.IsEncryptedPEMBlock(block) { + if block.Type != "SM4 ENCRYPTED KEY" { + return nil, errors.New("SM4: unknown type") + } + if pwd == nil { + return nil, errors.New("SM4: need passwd") + } + data, err := x509.DecryptPEMBlock(block, pwd) + if err != nil { + return nil, err + } + return data, nil + } + if block.Type != "SM4 KEY" { + return nil, errors.New("SM4: unknown type") + } + return block.Bytes, nil +} + +// ReadKeyFromPemFile will return SM4Key from filename that saved PEM format data. +func ReadKeyFromPemFile(FileName string, pwd []byte) (SM4Key, error) { + data, err := ioutil.ReadFile(FileName) + if err != nil { + return nil, err + } + return ReadKeyFromPem(data, pwd) +} + +// WriteKeyToPem will convert SM4Key to PEM format data and return it. +func WriteKeyToPem(key SM4Key, pwd []byte) ([]byte, error) { + if pwd != nil { + block, err := x509.EncryptPEMBlock(rand.Reader, + "SM4 ENCRYPTED KEY", key, pwd, x509.PEMCipherAES256) //Use AES256 algorithms to encrypt SM4KEY + if err != nil { + return nil, err + } + return pem.EncodeToMemory(block), nil + } else { + block := &pem.Block{ + Type: "SM4 KEY", + Bytes: key, + } + return pem.EncodeToMemory(block), nil + } +} + +// WriteKeyToPemFile will convert SM4Key to PEM format data, then write it +// into the input filename. +func WriteKeyToPemFile(FileName string, key SM4Key, pwd []byte) error { + var block *pem.Block + var err error + if pwd != nil { + block, err = x509.EncryptPEMBlock(rand.Reader, + "SM4 ENCRYPTED KEY", key, pwd, x509.PEMCipherAES256) + if err != nil { + return err + } + } else { + block = &pem.Block{ + Type: "SM4 KEY", + Bytes: key, + } + } + pemBytes := pem.EncodeToMemory(block) + err = ioutil.WriteFile(FileName, pemBytes, 0666) + if err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/txthinking/runnergroup/LICENSE b/vendor/github.com/txthinking/runnergroup/LICENSE new file mode 100644 index 00000000..3c27036d --- /dev/null +++ b/vendor/github.com/txthinking/runnergroup/LICENSE @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2020-present Cloud https://www.txthinking.com + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
diff --git a/vendor/github.com/txthinking/runnergroup/example_test.go b/vendor/github.com/txthinking/runnergroup/example_test.go
new file mode 100644
index 00000000..bea41d41
--- /dev/null
+++ b/vendor/github.com/txthinking/runnergroup/example_test.go
@@ -0,0 +1,45 @@
+package runnergroup_test
+
+import (
+	"context"
+	"log"
+	"net/http"
+	"time"
+
+	"github.com/txthinking/runnergroup"
+)
+
+func Example() {
+	g := runnergroup.New()
+
+	s := &http.Server{
+		Addr: ":9991",
+	}
+	g.Add(&runnergroup.Runner{
+		Start: func() error {
+			return s.ListenAndServe()
+		},
+		Stop: func() error {
+			return s.Shutdown(context.Background())
+		},
+	})
+
+	s1 := &http.Server{
+		Addr: ":9992",
+	}
+	g.Add(&runnergroup.Runner{
+		Start: func() error {
+			return s1.ListenAndServe()
+		},
+		Stop: func() error {
+			return s1.Shutdown(context.Background())
+		},
+	})
+
+	go func() {
+		time.Sleep(5 * time.Second)
+		log.Println(g.Done())
+	}()
+	log.Println(g.Wait())
+	// Output:
+}
diff --git a/vendor/github.com/txthinking/runnergroup/readme.md b/vendor/github.com/txthinking/runnergroup/readme.md
new file mode 100644
index 00000000..564c12eb
--- /dev/null
+++ b/vendor/github.com/txthinking/runnergroup/readme.md
@@ -0,0 +1,65 @@
+## RunnerGroup
+
+[![GoDoc](https://img.shields.io/badge/Go-Doc-blue.svg)](https://godoc.org/github.com/txthinking/runnergroup)
+[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/txthinking/runnergroup/blob/master/LICENSE)
+
+RunnerGroup is like [sync.WaitGroup](https://pkg.go.dev/sync?tab=doc#WaitGroup); the difference is that if one task stops, all will be stopped.
+
+❤️ A project by [txthinking.com](https://www.txthinking.com)
+
+### Install
+
+    $ go get github.com/txthinking/runnergroup
+
+### Example
+
+```
+import (
+	"context"
+	"log"
+	"net/http"
+	"time"
+
+	"github.com/txthinking/runnergroup"
+)
+
+func Example() {
+	g := runnergroup.New()
+
+	s := &http.Server{
+		Addr: ":9991",
+	}
+	g.Add(&runnergroup.Runner{
+		Start: func() error {
+			return s.ListenAndServe()
+		},
+		Stop: func() error {
+			return s.Shutdown(context.Background())
+		},
+	})
+
+	s1 := &http.Server{
+		Addr: ":9992",
+	}
+	g.Add(&runnergroup.Runner{
+		Start: func() error {
+			return s1.ListenAndServe()
+		},
+		Stop: func() error {
+			return s1.Shutdown(context.Background())
+		},
+	})
+
+	go func() {
+		time.Sleep(5 * time.Second)
+		log.Println(g.Done())
+	}()
+	log.Println(g.Wait())
+	// Output:
+}
+
+```
+
+## License
+
+Licensed under The MIT License
diff --git a/vendor/github.com/txthinking/runnergroup/runnergroup.go b/vendor/github.com/txthinking/runnergroup/runnergroup.go
new file mode 100644
index 00000000..7b3efadd
--- /dev/null
+++ b/vendor/github.com/txthinking/runnergroup/runnergroup.go
@@ -0,0 +1,96 @@
+package runnergroup
+
+import (
+	"sync"
+	"time"
+)
+
+// RunnerGroup is like sync.WaitGroup;
+// the difference is that if one task stops, all will be stopped.
+type RunnerGroup struct {
+	runners []*Runner
+	done    chan byte
+}
+
+type Runner struct {
+	// Start is a blocking function.
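+	// It is typically something like http.Server.ListenAndServe and is
+	// expected not to return until Stop is called (see the example above).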
+	Start func() error
+	// Stop is non-blocking; once Stop is called, Start must return.
+	// Notice: Stop may be called multiple times, even before Start.
+	Stop   func() error
+	lock   sync.Mutex
+	status int
+}
+
+func New() *RunnerGroup {
+	g := &RunnerGroup{}
+	g.runners = make([]*Runner, 0)
+	g.done = make(chan byte)
+	return g
+}
+
+func (g *RunnerGroup) Add(r *Runner) {
+	g.runners = append(g.runners, r)
+}
+
+// Call Wait after all tasks have been added.
+// It blocks until one Start returns, then stops all runners and
+// returns the result of the first Start that ended.
+func (g *RunnerGroup) Wait() error {
+	e := make(chan error)
+	for _, v := range g.runners {
+		v.status = 1
+		go func(v *Runner) {
+			err := v.Start()
+			v.lock.Lock()
+			v.status = 0
+			v.lock.Unlock()
+			select {
+			case <-g.done:
+			case e <- err:
+			}
+		}(v)
+	}
+	err := <-e
+	for _, v := range g.runners {
+		for {
+			v.lock.Lock()
+			if v.status == 0 {
+				v.lock.Unlock()
+				break
+			}
+			v.lock.Unlock()
+			_ = v.Stop()
+			time.Sleep(300 * time.Millisecond)
+		}
+	}
+	close(g.done)
+	return err
+}
+
+// Call Done to stop all runners.
+// It returns the first non-nil error returned by a Stop, but this is not
+// guaranteed, because a Start may also end on its own.
+func (g *RunnerGroup) Done() error {
+	if len(g.runners) == 0 {
+		return nil
+	}
+	var e error
+	for _, v := range g.runners {
+		for {
+			v.lock.Lock()
+			if v.status == 0 {
+				v.lock.Unlock()
+				break
+			}
+			v.lock.Unlock()
+			if err := v.Stop(); err != nil {
+				if e == nil {
+					e = err
+				}
+			}
+			time.Sleep(300 * time.Millisecond)
+		}
+	}
+	<-g.done
+	return e
+}
diff --git a/vendor/github.com/txthinking/socks5/.github/FUNDING.yml b/vendor/github.com/txthinking/socks5/.github/FUNDING.yml
new file mode 100644
index 00000000..b6c3bb2c
--- /dev/null
+++ b/vendor/github.com/txthinking/socks5/.github/FUNDING.yml
@@ -0,0 +1,12 @@
+# These are supported funding model platforms
+
+github: txthinking
+patreon: # Replace with a single Patreon username
+open_collective: # Replace with a single Open Collective username
+ko_fi: # Replace with a single Ko-fi username
+tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
+community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
+liberapay: # Replace with a single Liberapay username
+issuehunt: # Replace with a single IssueHunt username
+otechie: # Replace with a single Otechie username
+custom:
diff --git a/vendor/github.com/txthinking/socks5/.github/ISSUE_TEMPLATE b/vendor/github.com/txthinking/socks5/.github/ISSUE_TEMPLATE
new file mode 100644
index 00000000..94e28b59
--- /dev/null
+++ b/vendor/github.com/txthinking/socks5/.github/ISSUE_TEMPLATE
@@ -0,0 +1,10 @@
+#### Describe actual behavior
+
+#### What is your expected behavior
+
+#### Specifications like the version of the project, operating system, or hardware
+
+#### Steps to reproduce the problem
+0.
+1.
+2.
diff --git a/vendor/github.com/txthinking/socks5/.github/PULL_REQUEST_TEMPLATE b/vendor/github.com/txthinking/socks5/.github/PULL_REQUEST_TEMPLATE
new file mode 100644
index 00000000..489e7e43
--- /dev/null
+++ b/vendor/github.com/txthinking/socks5/.github/PULL_REQUEST_TEMPLATE
@@ -0,0 +1,8 @@
+Fixes # .
+ +Changes proposed in this pull request: +- +- +- + +@mentions diff --git a/vendor/github.com/txthinking/socks5/.gitignore b/vendor/github.com/txthinking/socks5/.gitignore new file mode 100644 index 00000000..9b78a338 --- /dev/null +++ b/vendor/github.com/txthinking/socks5/.gitignore @@ -0,0 +1,28 @@ +# IDEs +.vscode +.idea + +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/vendor/github.com/txthinking/socks5/LICENSE b/vendor/github.com/txthinking/socks5/LICENSE new file mode 100644 index 00000000..03aad451 --- /dev/null +++ b/vendor/github.com/txthinking/socks5/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2015-present Cloud https://www.txthinking.com + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/txthinking/socks5/README.md b/vendor/github.com/txthinking/socks5/README.md new file mode 100644 index 00000000..0e8c272d --- /dev/null +++ b/vendor/github.com/txthinking/socks5/README.md @@ -0,0 +1,96 @@ +## socks5 + +[中文](README_ZH.md) + +[![Go Report Card](https://goreportcard.com/badge/github.com/txthinking/socks5)](https://goreportcard.com/report/github.com/txthinking/socks5) +[![GoDoc](https://godoc.org/github.com/txthinking/socks5?status.svg)](https://godoc.org/github.com/txthinking/socks5) +[![Donate](https://img.shields.io/badge/Support-Donate-ff69b4.svg)](https://github.com/sponsors/txthinking) +[![Slack](https://img.shields.io/badge/Join-Slack-ff69b4.svg)](https://docs.google.com/forms/d/e/1FAIpQLSdzMwPtDue3QoezXSKfhW88BXp57wkbDXnLaqokJqLeSWP9vQ/viewform) + +SOCKS Protocol Version 5 Library. + +Full TCP/UDP and IPv4/IPv6 support. +Goals: KISS, less is more, small API, code is like the original protocol. 
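+
+For example, the structs map directly onto the wire bytes. A minimal sketch
+(the constants are from this package; the byte values follow RFC 1928):
+
+```
+var buf bytes.Buffer
+r := socks5.NewNegotiationRequest([]byte{socks5.MethodNone})
+r.WriteTo(&buf)
+// buf.Bytes() is now {0x05, 0x01, 0x00}: VER, NMETHODS, METHODS
+```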
+
+❤️ A project by [txthinking.com](https://www.txthinking.com)
+
+### Install
+```
+$ go get github.com/txthinking/socks5
+```
+
+### Structs mirror concepts in the protocol
+
+* Negotiation:
+  * `type NegotiationRequest struct`
+    * `func NewNegotiationRequest(methods []byte)`, in client
+    * `func (r *NegotiationRequest) WriteTo(w io.Writer)`, client writes to server
+    * `func NewNegotiationRequestFrom(r io.Reader)`, server reads from client
+  * `type NegotiationReply struct`
+    * `func NewNegotiationReply(method byte)`, in server
+    * `func (r *NegotiationReply) WriteTo(w io.Writer)`, server writes to client
+    * `func NewNegotiationReplyFrom(r io.Reader)`, client reads from server
+* User and password negotiation:
+  * `type UserPassNegotiationRequest struct`
+    * `func NewUserPassNegotiationRequest(username []byte, password []byte)`, in client
+    * `func (r *UserPassNegotiationRequest) WriteTo(w io.Writer)`, client writes to server
+    * `func NewUserPassNegotiationRequestFrom(r io.Reader)`, server reads from client
+  * `type UserPassNegotiationReply struct`
+    * `func NewUserPassNegotiationReply(status byte)`, in server
+    * `func (r *UserPassNegotiationReply) WriteTo(w io.Writer)`, server writes to client
+    * `func NewUserPassNegotiationReplyFrom(r io.Reader)`, client reads from server
+* Request:
+  * `type Request struct`
+    * `func NewRequest(cmd byte, atyp byte, dstaddr []byte, dstport []byte)`, in client
+    * `func (r *Request) WriteTo(w io.Writer)`, client writes to server
+    * `func NewRequestFrom(r io.Reader)`, server reads from client
+    * After the server gets the client's *Request, it processes it...
+* Reply:
+  * `type Reply struct`
+    * `func NewReply(rep byte, atyp byte, bndaddr []byte, bndport []byte)`, in server
+    * `func (r *Reply) WriteTo(w io.Writer)`, server writes to client
+    * `func NewReplyFrom(r io.Reader)`, client reads from server
+* Datagram:
+  * `type Datagram struct`
+    * `func NewDatagram(atyp byte, dstaddr []byte, dstport []byte, data []byte)`
+    * `func NewDatagramFromBytes(bb []byte)`
+    * `func (d *Datagram) Bytes()`
+
+### Advanced API
+
+**Server**. You can process the client's request yourself after reading the **Request** from the client. There is also an advanced interface:
+
+* `type Server struct`
+* `type Handler interface`
+  * `TCPHandle(*Server, *net.TCPConn, *Request) error`
+  * `UDPHandle(*Server, *net.UDPAddr, *Datagram) error`
+
+Example:
+
+```
+s, _ := NewClassicServer(addr, ip, username, password, tcpTimeout, udpTimeout)
+s.ListenAndServe(Handler)
+```
+
+* If you want a standard socks5 server, pass in nil
+* If you want to handle data by yourself, pass in a custom Handler
+
+**Client**. Here is a client that supports both TCP and UDP and returns a net.Conn.
+ +* `type Client struct` + +Example: + +``` +c, _ := socks5.NewClient(server, username, password, tcpTimeout, udpTimeout) +conn, _ := c.Dial(network, addr) +``` + +### Users: + + * Brook [https://github.com/txthinking/brook](https://github.com/txthinking/brook) + * Shiliew [https://www.shiliew.com](https://www.shiliew.com) + +## License + +Licensed under The MIT License diff --git a/vendor/github.com/txthinking/socks5/README_ZH.md b/vendor/github.com/txthinking/socks5/README_ZH.md new file mode 100644 index 00000000..8ac3ae30 --- /dev/null +++ b/vendor/github.com/txthinking/socks5/README_ZH.md @@ -0,0 +1,96 @@ +## socks5 + +[English](README.md) + +[![Go Report Card](https://goreportcard.com/badge/github.com/txthinking/socks5)](https://goreportcard.com/report/github.com/txthinking/socks5) +[![GoDoc](https://godoc.org/github.com/txthinking/socks5?status.svg)](https://godoc.org/github.com/txthinking/socks5) +[![捐赠](https://img.shields.io/badge/%E6%94%AF%E6%8C%81-%E6%8D%90%E8%B5%A0-ff69b4.svg)](https://github.com/sponsors/txthinking) +[![交流群](https://img.shields.io/badge/%E7%94%B3%E8%AF%B7%E5%8A%A0%E5%85%A5-%E4%BA%A4%E6%B5%81%E7%BE%A4-ff69b4.svg)](https://docs.google.com/forms/d/e/1FAIpQLSdzMwPtDue3QoezXSKfhW88BXp57wkbDXnLaqokJqLeSWP9vQ/viewform) + +SOCKS Protocol Version 5 Library. + +完整 TCP/UDP 和 IPv4/IPv6 支持. +目标: KISS, less is more, small API, code is like the original protocol. + +❤️ A project by [txthinking.com](https://www.txthinking.com) + +### 获取 +``` +$ go get github.com/txthinking/socks5 +``` + +### Struct的概念 对标 原始协议里的概念 + +* Negotiation: + * `type NegotiationRequest struct` + * `func NewNegotiationRequest(methods []byte)`, in client + * `func (r *NegotiationRequest) WriteTo(w io.Writer)`, client writes to server + * `func NewNegotiationRequestFrom(r io.Reader)`, server reads from client + * `type NegotiationReply struct` + * `func NewNegotiationReply(method byte)`, in server + * `func (r *NegotiationReply) WriteTo(w io.Writer)`, server writes to client + * `func NewNegotiationReplyFrom(r io.Reader)`, client reads from server +* User and password negotiation: + * `type UserPassNegotiationRequest struct` + * `func NewUserPassNegotiationRequest(username []byte, password []byte)`, in client + * `func (r *UserPassNegotiationRequest) WriteTo(w io.Writer)`, client writes to server + * `func NewUserPassNegotiationRequestFrom(r io.Reader)`, server reads from client + * `type UserPassNegotiationReply struct` + * `func NewUserPassNegotiationReply(status byte)`, in server + * `func (r *UserPassNegotiationReply) WriteTo(w io.Writer)`, server writes to client + * `func NewUserPassNegotiationReplyFrom(r io.Reader)`, client reads from server +* Request: + * `type Request struct` + * `func NewRequest(cmd byte, atyp byte, dstaddr []byte, dstport []byte)`, in client + * `func (r *Request) WriteTo(w io.Writer)`, client writes to server + * `func NewRequestFrom(r io.Reader)`, server reads from client + * After server gets the client's *Request, processes... +* Reply: + * `type Reply struct` + * `func NewReply(rep byte, atyp byte, bndaddr []byte, bndport []byte)`, in server + * `func (r *Reply) WriteTo(w io.Writer)`, server writes to client + * `func NewReplyFrom(r io.Reader)`, client reads from server +* Datagram: + * `type Datagram struct` + * `func NewDatagram(atyp byte, dstaddr []byte, dstport []byte, data []byte)` + * `func NewDatagramFromBytes(bb []byte)` + * `func (d *Datagram) Bytes()` + +### 高级 API + +**Server**. 你可以自己处理client请求在读取**Request**后. 
同时, 这里有一个高级接口 + +* `type Server struct` +* `type Handler interface` + * `TCPHandle(*Server, *net.TCPConn, *Request) error` + * `UDPHandle(*Server, *net.UDPAddr, *Datagram) error` + +举例: + +``` +s, _ := NewClassicServer(addr, ip, username, password, tcpTimeout, udpTimeout) +s.ListenAndServe(Handler) +``` + +* 如果你想要一个标准socks5 server, 传入nil即可 +* 如果你想要自己处理请求, 传入一个你自己的Handler + +**Client**. 这里有个socks5 client, 支持TCP和UDP, 返回net.Conn. + +* `type Client struct` + +举例: + +``` +c, _ := socks5.NewClient(server, username, password, tcpTimeout, udpTimeout) +conn, _ := c.Dial(network, addr) +``` + +### 用户: + + * Brook [https://github.com/txthinking/brook](https://github.com/txthinking/brook) + * Shiliew [https://www.shiliew.com](https://www.shiliew.com) + +## 开源协议 + +基于 MIT 协议开源 diff --git a/vendor/github.com/txthinking/socks5/bind.go b/vendor/github.com/txthinking/socks5/bind.go new file mode 100644 index 00000000..429949cc --- /dev/null +++ b/vendor/github.com/txthinking/socks5/bind.go @@ -0,0 +1,11 @@ +package socks5 + +import ( + "errors" + "net" +) + +// TODO +func (r *Request) bind(c net.Conn) error { + return errors.New("Unsupport BIND now") +} diff --git a/vendor/github.com/txthinking/socks5/client.go b/vendor/github.com/txthinking/socks5/client.go new file mode 100644 index 00000000..0dca3fc2 --- /dev/null +++ b/vendor/github.com/txthinking/socks5/client.go @@ -0,0 +1,287 @@ +package socks5 + +import ( + "errors" + "net" + "time" +) + +// Client is socks5 client wrapper +type Client struct { + Server string + UserName string + Password string + // On cmd UDP, let server control the tcp and udp connection relationship + TCPConn *net.TCPConn + UDPConn *net.UDPConn + RemoteAddress net.Addr + TCPTimeout int + UDPTimeout int + // HijackServerUDPAddr can let client control which server UDP address to connect to after sending request, + // In most cases, you should ignore this, according to the standard server will return the address in reply, + // More: https://github.com/txthinking/socks5/pull/8. 
+ HijackServerUDPAddr func(*Reply) (*net.UDPAddr, error) +} + +// This is just create a client, you need to use Dial to create conn +func NewClient(addr, username, password string, tcpTimeout, udpTimeout int) (*Client, error) { + c := &Client{ + Server: addr, + UserName: username, + Password: password, + TCPTimeout: tcpTimeout, + UDPTimeout: udpTimeout, + } + return c, nil +} + +func (c *Client) Dial(network, addr string) (net.Conn, error) { + return c.DialWithLocalAddr(network, "", addr, nil) +} + +func (c *Client) DialWithLocalAddr(network, src, dst string, remoteAddr net.Addr) (net.Conn, error) { + c = &Client{ + Server: c.Server, + UserName: c.UserName, + Password: c.Password, + TCPTimeout: c.TCPTimeout, + UDPTimeout: c.UDPTimeout, + RemoteAddress: remoteAddr, + HijackServerUDPAddr: c.HijackServerUDPAddr, + } + var err error + if network == "tcp" { + if c.RemoteAddress == nil { + c.RemoteAddress, err = net.ResolveTCPAddr("tcp", dst) + if err != nil { + return nil, err + } + } + var la *net.TCPAddr + if src != "" { + la, err = net.ResolveTCPAddr("tcp", src) + if err != nil { + return nil, err + } + } + if err := c.Negotiate(la); err != nil { + return nil, err + } + a, h, p, err := ParseAddress(dst) + if err != nil { + return nil, err + } + if a == ATYPDomain { + h = h[1:] + } + if _, err := c.Request(NewRequest(CmdConnect, a, h, p)); err != nil { + return nil, err + } + return c, nil + } + if network == "udp" { + if c.RemoteAddress == nil { + c.RemoteAddress, err = net.ResolveUDPAddr("udp", dst) + if err != nil { + return nil, err + } + } + var la *net.TCPAddr + if src != "" { + la, err = net.ResolveTCPAddr("tcp", src) + if err != nil { + return nil, err + } + } + if err := c.Negotiate(la); err != nil { + return nil, err + } + + var laddr *net.UDPAddr + if src != "" { + laddr, err = net.ResolveUDPAddr("udp", src) + if err != nil { + return nil, err + } + } + if src == "" { + laddr = &net.UDPAddr{ + IP: c.TCPConn.LocalAddr().(*net.TCPAddr).IP, + Port: c.TCPConn.LocalAddr().(*net.TCPAddr).Port, + Zone: c.TCPConn.LocalAddr().(*net.TCPAddr).Zone, + } + } + a, h, p, err := ParseAddress(laddr.String()) + if err != nil { + return nil, err + } + rp, err := c.Request(NewRequest(CmdUDP, a, h, p)) + if err != nil { + return nil, err + } + var raddr *net.UDPAddr + if c.HijackServerUDPAddr == nil { + raddr, err = net.ResolveUDPAddr("udp", rp.Address()) + if err != nil { + return nil, err + } + } + if c.HijackServerUDPAddr != nil { + raddr, err = c.HijackServerUDPAddr(rp) + if err != nil { + return nil, err + } + } + c.UDPConn, err = Dial.DialUDP("udp", laddr, raddr) + if err != nil { + return nil, err + } + if c.UDPTimeout != 0 { + if err := c.UDPConn.SetDeadline(time.Now().Add(time.Duration(c.UDPTimeout) * time.Second)); err != nil { + return nil, err + } + } + return c, nil + } + return nil, errors.New("unsupport network") +} + +func (c *Client) Read(b []byte) (int, error) { + if c.UDPConn == nil { + return c.TCPConn.Read(b) + } + n, err := c.UDPConn.Read(b) + if err != nil { + return 0, err + } + d, err := NewDatagramFromBytes(b[0:n]) + if err != nil { + return 0, err + } + n = copy(b, d.Data) + return n, nil +} + +func (c *Client) Write(b []byte) (int, error) { + if c.UDPConn == nil { + return c.TCPConn.Write(b) + } + a, h, p, err := ParseAddress(c.RemoteAddress.String()) + if err != nil { + return 0, err + } + if a == ATYPDomain { + h = h[1:] + } + d := NewDatagram(a, h, p, b) + b1 := d.Bytes() + n, err := c.UDPConn.Write(b1) + if err != nil { + return 0, err + } + if len(b1) != n { + return 0, 
errors.New("not write full") + } + return len(b), nil +} + +func (c *Client) Close() error { + if c.UDPConn == nil { + return c.TCPConn.Close() + } + if c.TCPConn != nil { + c.TCPConn.Close() + } + return c.UDPConn.Close() +} + +func (c *Client) LocalAddr() net.Addr { + if c.UDPConn == nil { + return c.TCPConn.LocalAddr() + } + return c.UDPConn.LocalAddr() +} + +func (c *Client) RemoteAddr() net.Addr { + return c.RemoteAddress +} + +func (c *Client) SetDeadline(t time.Time) error { + if c.UDPConn == nil { + return c.TCPConn.SetDeadline(t) + } + return c.UDPConn.SetDeadline(t) +} + +func (c *Client) SetReadDeadline(t time.Time) error { + if c.UDPConn == nil { + return c.TCPConn.SetReadDeadline(t) + } + return c.UDPConn.SetReadDeadline(t) +} + +func (c *Client) SetWriteDeadline(t time.Time) error { + if c.UDPConn == nil { + return c.TCPConn.SetWriteDeadline(t) + } + return c.UDPConn.SetWriteDeadline(t) +} + +func (c *Client) Negotiate(laddr *net.TCPAddr) error { + raddr, err := net.ResolveTCPAddr("tcp", c.Server) + if err != nil { + return err + } + c.TCPConn, err = Dial.DialTCP("tcp", laddr, raddr) + if err != nil { + return err + } + if c.TCPTimeout != 0 { + if err := c.TCPConn.SetDeadline(time.Now().Add(time.Duration(c.TCPTimeout) * time.Second)); err != nil { + return err + } + } + m := MethodNone + if c.UserName != "" && c.Password != "" { + m = MethodUsernamePassword + } + rq := NewNegotiationRequest([]byte{m}) + if _, err := rq.WriteTo(c.TCPConn); err != nil { + return err + } + rp, err := NewNegotiationReplyFrom(c.TCPConn) + if err != nil { + return err + } + if rp.Method != m { + return errors.New("Unsupport method") + } + if m == MethodUsernamePassword { + urq := NewUserPassNegotiationRequest([]byte(c.UserName), []byte(c.Password)) + if _, err := urq.WriteTo(c.TCPConn); err != nil { + return err + } + urp, err := NewUserPassNegotiationReplyFrom(c.TCPConn) + if err != nil { + return err + } + if urp.Status != UserPassStatusSuccess { + return ErrUserPassAuth + } + } + return nil +} + +func (c *Client) Request(r *Request) (*Reply, error) { + if _, err := r.WriteTo(c.TCPConn); err != nil { + return nil, err + } + rp, err := NewReplyFrom(c.TCPConn) + if err != nil { + return nil, err + } + if rp.Rep != RepSuccess { + return nil, errors.New("Host unreachable") + } + return rp, nil +} diff --git a/vendor/github.com/txthinking/socks5/client_side.go b/vendor/github.com/txthinking/socks5/client_side.go new file mode 100644 index 00000000..ff80ff6e --- /dev/null +++ b/vendor/github.com/txthinking/socks5/client_side.go @@ -0,0 +1,213 @@ +package socks5 + +import ( + "errors" + "io" + "log" +) + +var ( + // ErrBadReply is the error when read reply + ErrBadReply = errors.New("Bad Reply") +) + +// NewNegotiationRequest return negotiation request packet can be writed into server +func NewNegotiationRequest(methods []byte) *NegotiationRequest { + return &NegotiationRequest{ + Ver: Ver, + NMethods: byte(len(methods)), + Methods: methods, + } +} + +// WriteTo write negotiation request packet into server +func (r *NegotiationRequest) WriteTo(w io.Writer) (int64, error) { + var n int + i, err := w.Write([]byte{r.Ver}) + n = n + i + if err != nil { + return int64(n), err + } + i, err = w.Write([]byte{r.NMethods}) + n = n + i + if err != nil { + return int64(n), err + } + i, err = w.Write(r.Methods) + n = n + i + if err != nil { + return int64(n), err + } + if Debug { + log.Printf("Sent NegotiationRequest: %#v %#v %#v\n", r.Ver, r.NMethods, r.Methods) + } + return int64(n), nil +} + +// 
NewNegotiationReplyFrom reads a negotiation reply packet from the server +func NewNegotiationReplyFrom(r io.Reader) (*NegotiationReply, error) { + bb := make([]byte, 2) + if _, err := io.ReadFull(r, bb); err != nil { + return nil, err + } + if bb[0] != Ver { + return nil, ErrVersion + } + if Debug { + log.Printf("Got NegotiationReply: %#v %#v\n", bb[0], bb[1]) + } + return &NegotiationReply{ + Ver: bb[0], + Method: bb[1], + }, nil +} + +// NewUserPassNegotiationRequest returns a username/password negotiation request packet that can be written to the server +func NewUserPassNegotiationRequest(username []byte, password []byte) *UserPassNegotiationRequest { + return &UserPassNegotiationRequest{ + Ver: UserPassVer, + Ulen: byte(len(username)), + Uname: username, + Plen: byte(len(password)), + Passwd: password, + } +} + +// WriteTo writes the username/password negotiation request packet to the server +func (r *UserPassNegotiationRequest) WriteTo(w io.Writer) (int64, error) { + var n int + i, err := w.Write([]byte{r.Ver, r.Ulen}) + n = n + i + if err != nil { + return int64(n), err + } + i, err = w.Write(r.Uname) + n = n + i + if err != nil { + return int64(n), err + } + i, err = w.Write([]byte{r.Plen}) + n = n + i + if err != nil { + return int64(n), err + } + i, err = w.Write(r.Passwd) + n = n + i + if err != nil { + return int64(n), err + } + if Debug { + log.Printf("Sent UserPassNegotiationRequest: %#v %#v %#v %#v %#v\n", r.Ver, r.Ulen, r.Uname, r.Plen, r.Passwd) + } + return int64(n), nil +} + +// NewUserPassNegotiationReplyFrom reads a username/password negotiation reply packet from the server +func NewUserPassNegotiationReplyFrom(r io.Reader) (*UserPassNegotiationReply, error) { + bb := make([]byte, 2) + if _, err := io.ReadFull(r, bb); err != nil { + return nil, err + } + if bb[0] != UserPassVer { + return nil, ErrUserPassVersion + } + if Debug { + log.Printf("Got UserPassNegotiationReply: %#v %#v \n", bb[0], bb[1]) + } + return &UserPassNegotiationReply{ + Ver: bb[0], + Status: bb[1], + }, nil +} + +// NewRequest returns a request packet that can be written to the server; dstaddr must not include the domain length byte, it is prepended here +func NewRequest(cmd byte, atyp byte, dstaddr []byte, dstport []byte) *Request { + if atyp == ATYPDomain { + dstaddr = append([]byte{byte(len(dstaddr))}, dstaddr...)
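+ // e.g. dstaddr "a.com" becomes {0x05, 'a', '.', 'c', 'o', 'm'}: a one-byte length prefix followed by the name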
+ } + return &Request{ + Ver: Ver, + Cmd: cmd, + Rsv: 0x00, + Atyp: atyp, + DstAddr: dstaddr, + DstPort: dstport, + } +} + +// WriteTo writes the request packet to the server +func (r *Request) WriteTo(w io.Writer) (int64, error) { + var n int + i, err := w.Write([]byte{r.Ver, r.Cmd, r.Rsv, r.Atyp}) + n = n + i + if err != nil { + return int64(n), err + } + i, err = w.Write(r.DstAddr) + n = n + i + if err != nil { + return int64(n), err + } + i, err = w.Write(r.DstPort) + n = n + i + if err != nil { + return int64(n), err + } + if Debug { + log.Printf("Sent Request: %#v %#v %#v %#v %#v %#v\n", r.Ver, r.Cmd, r.Rsv, r.Atyp, r.DstAddr, r.DstPort) + } + return int64(n), nil +} + +// NewReplyFrom reads a reply packet from the server +func NewReplyFrom(r io.Reader) (*Reply, error) { + bb := make([]byte, 4) + if _, err := io.ReadFull(r, bb); err != nil { + return nil, err + } + if bb[0] != Ver { + return nil, ErrVersion + } + var addr []byte + if bb[3] == ATYPIPv4 { + addr = make([]byte, 4) + if _, err := io.ReadFull(r, addr); err != nil { + return nil, err + } + } else if bb[3] == ATYPIPv6 { + addr = make([]byte, 16) + if _, err := io.ReadFull(r, addr); err != nil { + return nil, err + } + } else if bb[3] == ATYPDomain { + dal := make([]byte, 1) + if _, err := io.ReadFull(r, dal); err != nil { + return nil, err + } + if dal[0] == 0 { + return nil, ErrBadReply + } + addr = make([]byte, int(dal[0])) + if _, err := io.ReadFull(r, addr); err != nil { + return nil, err + } + addr = append(dal, addr...) + } else { + return nil, ErrBadReply + } + port := make([]byte, 2) + if _, err := io.ReadFull(r, port); err != nil { + return nil, err + } + if Debug { + log.Printf("Got Reply: %#v %#v %#v %#v %#v %#v\n", bb[0], bb[1], bb[2], bb[3], addr, port) + } + return &Reply{ + Ver: bb[0], + Rep: bb[1], + Rsv: bb[2], + Atyp: bb[3], + BndAddr: addr, + BndPort: port, + }, nil +} diff --git a/vendor/github.com/txthinking/socks5/connect.go b/vendor/github.com/txthinking/socks5/connect.go new file mode 100644 index 00000000..11188335 --- /dev/null +++ b/vendor/github.com/txthinking/socks5/connect.go @@ -0,0 +1,49 @@ +package socks5 + +import ( + "io" + "log" + "net" +) + +// Connect dials the remote conn you want to reach with your dialer +// Either an error or an OK reply is written. 
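+// On a failed dial a RepHostUnreachable reply is written to w before the dial error is returned; on success the reply carries the local address of the outbound TCP connection.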
+func (r *Request) Connect(w io.Writer) (*net.TCPConn, error) { + if Debug { + log.Println("Call:", r.Address()) + } + tmp, err := Dial.Dial("tcp", r.Address()) + if err != nil { + var p *Reply + if r.Atyp == ATYPIPv4 || r.Atyp == ATYPDomain { + p = NewReply(RepHostUnreachable, ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}) + } else { + p = NewReply(RepHostUnreachable, ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00}) + } + if _, err := p.WriteTo(w); err != nil { + return nil, err + } + return nil, err + } + rc := tmp.(*net.TCPConn) + + a, addr, port, err := ParseAddress(rc.LocalAddr().String()) + if err != nil { + var p *Reply + if r.Atyp == ATYPIPv4 || r.Atyp == ATYPDomain { + p = NewReply(RepHostUnreachable, ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}) + } else { + p = NewReply(RepHostUnreachable, ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00}) + } + if _, err := p.WriteTo(w); err != nil { + return nil, err + } + return nil, err + } + p := NewReply(RepSuccess, a, addr, port) + if _, err := p.WriteTo(w); err != nil { + return nil, err + } + + return rc, nil +} diff --git a/vendor/github.com/txthinking/socks5/example_test.go b/vendor/github.com/txthinking/socks5/example_test.go new file mode 100644 index 00000000..28fe58d2 --- /dev/null +++ b/vendor/github.com/txthinking/socks5/example_test.go @@ -0,0 +1,73 @@ +package socks5_test + +import ( + "encoding/hex" + "io/ioutil" + "log" + "net" + "net/http" + + "github.com/txthinking/socks5" +) + +func ExampleServer() { + s, err := socks5.NewClassicServer("127.0.0.1:1080", "127.0.0.1", "", "", 0, 60) + if err != nil { + panic(err) + } + // You can pass in custom Handler + s.ListenAndServe(nil) + // #Output: +} + +func ExampleClient_tcp() { + c, err := socks5.NewClient("127.0.0.1:1080", "", "", 0, 60) + if err != nil { + panic(err) + } + client := &http.Client{ + Transport: &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { + return c.Dial(network, addr) + }, + }, + } + res, err := client.Get("https://ifconfig.co") + if err != nil { + panic(err) + } + defer res.Body.Close() + b, err := ioutil.ReadAll(res.Body) + if err != nil { + panic(err) + } + log.Println(string(b)) + // Output: +} + +func ExampleClient_udp() { + c, err := socks5.NewClient("127.0.0.1:1080", "", "", 0, 60) + if err != nil { + panic(err) + } + conn, err := c.Dial("udp", "8.8.8.8:53") + if err != nil { + panic(err) + } + b, err := hex.DecodeString("0001010000010000000000000a74787468696e6b696e6703636f6d0000010001") + if err != nil { + panic(err) + } + if _, err := conn.Write(b); err != nil { + panic(err) + } + b = make([]byte, 2048) + n, err := conn.Read(b) + if err != nil { + panic(err) + } + b = b[:n] + b = b[len(b)-4:] + log.Println(net.IPv4(b[0], b[1], b[2], b[3])) + // Output: +} diff --git a/vendor/github.com/txthinking/socks5/go.mod b/vendor/github.com/txthinking/socks5/go.mod new file mode 100644 index 00000000..fcf699cf --- /dev/null +++ b/vendor/github.com/txthinking/socks5/go.mod @@ -0,0 +1,9 @@ +module github.com/txthinking/socks5 + +go 1.16 + +require ( + github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf + github.com/txthinking/x v0.0.0-20210326105829-476fab902fbe +) diff --git a/vendor/github.com/txthinking/socks5/go.sum b/vendor/github.com/txthinking/socks5/go.sum new file mode 100644 index 00000000..5c8da0db --- /dev/null +++ b/vendor/github.com/txthinking/socks5/go.sum @@ -0,0 +1,6 @@ +github.com/patrickmn/go-cache v2.1.0+incompatible 
h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf h1:7PflaKRtU4np/epFxRXlFhlzLXZzKFrH5/I4so5Ove0= +github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf/go.mod h1:CLUSJbazqETbaR+i0YAhXBICV9TrKH93pziccMhmhpM= +github.com/txthinking/x v0.0.0-20210326105829-476fab902fbe h1:gMWxZxBFRAXqoGkwkYlPX2zvyyKNWJpxOxCrjqJkm5A= +github.com/txthinking/x v0.0.0-20210326105829-476fab902fbe/go.mod h1:WgqbSEmUYSjEV3B1qmee/PpP2NYEz4bL9/+mF1ma+s4= diff --git a/vendor/github.com/txthinking/socks5/init.go b/vendor/github.com/txthinking/socks5/init.go new file mode 100644 index 00000000..0febc7fa --- /dev/null +++ b/vendor/github.com/txthinking/socks5/init.go @@ -0,0 +1,13 @@ +package socks5 + +import ( + "github.com/txthinking/x" +) + +// Debug enable debug log +var Debug bool +var Dial x.Dialer = x.DefaultDial + +func init() { + // log.SetFlags(log.LstdFlags | log.Lshortfile) +} diff --git a/vendor/github.com/txthinking/socks5/server.go b/vendor/github.com/txthinking/socks5/server.go new file mode 100644 index 00000000..fc42682d --- /dev/null +++ b/vendor/github.com/txthinking/socks5/server.go @@ -0,0 +1,453 @@ +package socks5 + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "strings" + "time" + + cache "github.com/patrickmn/go-cache" + "github.com/txthinking/runnergroup" +) + +var ( + // ErrUnsupportCmd is the error when got unsupport command + ErrUnsupportCmd = errors.New("Unsupport Command") + // ErrUserPassAuth is the error when got invalid username or password + ErrUserPassAuth = errors.New("Invalid Username or Password for Auth") +) + +// Server is socks5 server wrapper +type Server struct { + UserName string + Password string + Method byte + SupportedCommands []byte + TCPAddr *net.TCPAddr + UDPAddr *net.UDPAddr + ServerAddr *net.UDPAddr + TCPListen *net.TCPListener + UDPConn *net.UDPConn + UDPExchanges *cache.Cache + TCPTimeout int + UDPTimeout int + Handle Handler + AssociatedUDP *cache.Cache + RunnerGroup *runnergroup.RunnerGroup + // RFC: [UDP ASSOCIATE] The server MAY use this information to limit access to the association. Default false, no limit. 
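+ // When true, UDPHandle serves only datagrams whose source address currently has a UDP-associate TCP connection tracked in AssociatedUDP.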
+ LimitUDP bool +} + +// UDPExchange is used to store a client address and its remote connection +type UDPExchange struct { + ClientAddr *net.UDPAddr + RemoteConn *net.UDPConn +} + +// NewClassicServer returns a server that allows the none method (or username/password auth when credentials are supplied) +func NewClassicServer(addr, ip, username, password string, tcpTimeout, udpTimeout int) (*Server, error) { + _, p, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + taddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + uaddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + saddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(ip, p)) + if err != nil { + return nil, err + } + m := MethodNone + if username != "" && password != "" { + m = MethodUsernamePassword + } + cs := cache.New(cache.NoExpiration, cache.NoExpiration) + cs1 := cache.New(cache.NoExpiration, cache.NoExpiration) + s := &Server{ + Method: m, + UserName: username, + Password: password, + SupportedCommands: []byte{CmdConnect, CmdUDP}, + TCPAddr: taddr, + UDPAddr: uaddr, + ServerAddr: saddr, + UDPExchanges: cs, + TCPTimeout: tcpTimeout, + UDPTimeout: udpTimeout, + AssociatedUDP: cs1, + RunnerGroup: runnergroup.New(), + } + return s, nil +} + +// Negotiate handles the negotiation packet. +// This method does not handle the gssapi(0x01) method yet. +// Either an error or an OK reply is written. +func (s *Server) Negotiate(rw io.ReadWriter) error { + rq, err := NewNegotiationRequestFrom(rw) + if err != nil { + return err + } + var got bool + var m byte + for _, m = range rq.Methods { + if m == s.Method { + got = true + } + } + if !got { + rp := NewNegotiationReply(MethodUnsupportAll) + if _, err := rp.WriteTo(rw); err != nil { + return err + } + return errors.New("Unsupported method") + } + rp := NewNegotiationReply(s.Method) + if _, err := rp.WriteTo(rw); err != nil { + return err + } + + if s.Method == MethodUsernamePassword { + urq, err := NewUserPassNegotiationRequestFrom(rw) + if err != nil { + return err + } + if string(urq.Uname) != s.UserName || string(urq.Passwd) != s.Password { + urp := NewUserPassNegotiationReply(UserPassStatusFailure) + if _, err := urp.WriteTo(rw); err != nil { + return err + } + return ErrUserPassAuth + } + urp := NewUserPassNegotiationReply(UserPassStatusSuccess) + if _, err := urp.WriteTo(rw); err != nil { + return err + } + } + return nil +} + +// GetRequest reads the request packet from the client and checks its command against SupportedCommands. +// An error reply is written on failure. 
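+// Unsupported commands are answered with a RepCommandNotSupported reply and ErrUnsupportCmd is returned.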
+func (s *Server) GetRequest(rw io.ReadWriter) (*Request, error) { + r, err := NewRequestFrom(rw) + if err != nil { + return nil, err + } + var supported bool + for _, c := range s.SupportedCommands { + if r.Cmd == c { + supported = true + break + } + } + if !supported { + var p *Reply + if r.Atyp == ATYPIPv4 || r.Atyp == ATYPDomain { + p = NewReply(RepCommandNotSupported, ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}) + } else { + p = NewReply(RepCommandNotSupported, ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00}) + } + if _, err := p.WriteTo(rw); err != nil { + return nil, err + } + return nil, ErrUnsupportCmd + } + return r, nil +} + +// Run server +func (s *Server) ListenAndServe(h Handler) error { + if h == nil { + s.Handle = &DefaultHandle{} + } else { + s.Handle = h + } + s.RunnerGroup.Add(&runnergroup.Runner{ + Start: func() error { + return s.RunTCPServer() + }, + Stop: func() error { + if s.TCPListen != nil { + return s.TCPListen.Close() + } + return nil + }, + }) + s.RunnerGroup.Add(&runnergroup.Runner{ + Start: func() error { + return s.RunUDPServer() + }, + Stop: func() error { + if s.UDPConn != nil { + return s.UDPConn.Close() + } + return nil + }, + }) + return s.RunnerGroup.Wait() +} + +// RunTCPServer starts tcp server +func (s *Server) RunTCPServer() error { + var err error + s.TCPListen, err = net.ListenTCP("tcp", s.TCPAddr) + if err != nil { + return err + } + defer s.TCPListen.Close() + for { + c, err := s.TCPListen.AcceptTCP() + if err != nil { + return err + } + go func(c *net.TCPConn) { + defer c.Close() + if s.TCPTimeout != 0 { + if err := c.SetDeadline(time.Now().Add(time.Duration(s.TCPTimeout) * time.Second)); err != nil { + log.Println(err) + return + } + } + if err := s.Negotiate(c); err != nil { + log.Println(err) + return + } + r, err := s.GetRequest(c) + if err != nil { + log.Println(err) + return + } + if err := s.Handle.TCPHandle(s, c, r); err != nil { + log.Println(err) + } + }(c) + } + return nil +} + +// RunUDPServer starts udp server +func (s *Server) RunUDPServer() error { + var err error + s.UDPConn, err = net.ListenUDP("udp", s.UDPAddr) + if err != nil { + return err + } + defer s.UDPConn.Close() + for { + b := make([]byte, 65507) + n, addr, err := s.UDPConn.ReadFromUDP(b) + if err != nil { + return err + } + go func(addr *net.UDPAddr, b []byte) { + d, err := NewDatagramFromBytes(b) + if err != nil { + log.Println(err) + return + } + if d.Frag != 0x00 { + log.Println("Ignore frag", d.Frag) + return + } + if err := s.Handle.UDPHandle(s, addr, d); err != nil { + log.Println(err) + return + } + }(addr, b[0:n]) + } + return nil +} + +// Stop server +func (s *Server) Shutdown() error { + return s.RunnerGroup.Done() +} + +// Handler handle tcp, udp request +type Handler interface { + // Request has not been replied yet + TCPHandle(*Server, *net.TCPConn, *Request) error + UDPHandle(*Server, *net.UDPAddr, *Datagram) error +} + +// DefaultHandle implements Handler interface +type DefaultHandle struct { +} + +// TCPHandle auto handle request. You may prefer to do yourself. 
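+// For CmdConnect it relays bytes between client and remote in both directions until one side closes or times out; for CmdUDP it blocks draining the TCP connection, so the UDP association lives exactly as long as that TCP connection.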
+func (h *DefaultHandle) TCPHandle(s *Server, c *net.TCPConn, r *Request) error { + if r.Cmd == CmdConnect { + rc, err := r.Connect(c) + if err != nil { + return err + } + defer rc.Close() + go func() { + var bf [1024 * 2]byte + for { + if s.TCPTimeout != 0 { + if err := rc.SetDeadline(time.Now().Add(time.Duration(s.TCPTimeout) * time.Second)); err != nil { + return + } + } + i, err := rc.Read(bf[:]) + if err != nil { + return + } + if _, err := c.Write(bf[0:i]); err != nil { + return + } + } + }() + var bf [1024 * 2]byte + for { + if s.TCPTimeout != 0 { + if err := c.SetDeadline(time.Now().Add(time.Duration(s.TCPTimeout) * time.Second)); err != nil { + return nil + } + } + i, err := c.Read(bf[:]) + if err != nil { + return nil + } + if _, err := rc.Write(bf[0:i]); err != nil { + return nil + } + } + return nil + } + if r.Cmd == CmdUDP { + caddr, err := r.UDP(c, s.ServerAddr) + if err != nil { + return err + } + ch := make(chan byte) + defer close(ch) + s.AssociatedUDP.Set(caddr.String(), ch, -1) + defer s.AssociatedUDP.Delete(caddr.String()) + io.Copy(ioutil.Discard, c) + if Debug { + log.Printf("A tcp connection that udp %#v associated closed\n", caddr.String()) + } + return nil + } + return ErrUnsupportCmd +} + +// UDPHandle auto handle packet. You may prefer to do yourself. +func (h *DefaultHandle) UDPHandle(s *Server, addr *net.UDPAddr, d *Datagram) error { + src := addr.String() + var ch chan byte + if s.LimitUDP { + any, ok := s.AssociatedUDP.Get(src) + if !ok { + return fmt.Errorf("This udp address %s is not associated with tcp", src) + } + ch = any.(chan byte) + } + send := func(ue *UDPExchange, data []byte) error { + select { + case <-ch: + return fmt.Errorf("This udp address %s is not associated with tcp", src) + default: + _, err := ue.RemoteConn.Write(data) + if err != nil { + return err + } + if Debug { + log.Printf("Sent UDP data to remote. client: %#v server: %#v remote: %#v data: %#v\n", ue.ClientAddr.String(), ue.RemoteConn.LocalAddr().String(), ue.RemoteConn.RemoteAddr().String(), data) + } + } + return nil + } + + dst := d.Address() + var ue *UDPExchange + iue, ok := s.UDPExchanges.Get(src + dst) + if ok { + ue = iue.(*UDPExchange) + err := send(ue, d.Data) + if err == nil { + return nil + } + if !strings.Contains(err.Error(), "closed") { + return err + } + } + + if Debug { + log.Printf("Call udp: %#v\n", dst) + } + raddr, err := net.ResolveUDPAddr("udp", dst) + if err != nil { + return err + } + rc, err := Dial.DialUDP("udp", nil, raddr) + if err != nil { + return err + } + ue = &UDPExchange{ + ClientAddr: addr, + RemoteConn: rc, + } + if Debug { + log.Printf("Created remote UDP conn for client. client: %#v server: %#v remote: %#v\n", addr.String(), ue.RemoteConn.LocalAddr().String(), d.Address()) + } + if err := send(ue, d.Data); err != nil { + ue.RemoteConn.Close() + return err + } + s.UDPExchanges.Set(src+dst, ue, -1) + go func(ue *UDPExchange, dst string) { + defer func() { + ue.RemoteConn.Close() + s.UDPExchanges.Delete(ue.ClientAddr.String() + dst) + }() + var b [65507]byte + for { + select { + case <-ch: + if Debug { + log.Printf("The tcp that udp address %s associated closed\n", ue.ClientAddr.String()) + } + return + default: + if s.UDPTimeout != 0 { + if err := ue.RemoteConn.SetDeadline(time.Now().Add(time.Duration(s.UDPTimeout) * time.Second)); err != nil { + log.Println(err) + return + } + } + n, err := ue.RemoteConn.Read(b[:]) + if err != nil { + return + } + if Debug { + log.Printf("Got UDP data from remote. 
client: %#v server: %#v remote: %#v data: %#v\n", ue.ClientAddr.String(), ue.RemoteConn.LocalAddr().String(), ue.RemoteConn.RemoteAddr().String(), b[0:n]) + } + a, addr, port, err := ParseAddress(dst) + if err != nil { + log.Println(err) + return + } + d1 := NewDatagram(a, addr, port, b[0:n]) + if _, err := s.UDPConn.WriteToUDP(d1.Bytes(), ue.ClientAddr); err != nil { + return + } + if Debug { + log.Printf("Sent Datagram. client: %#v server: %#v remote: %#v data: %#v %#v %#v %#v %#v %#v datagram address: %#v\n", ue.ClientAddr.String(), ue.RemoteConn.LocalAddr().String(), ue.RemoteConn.RemoteAddr().String(), d1.Rsv, d1.Frag, d1.Atyp, d1.DstAddr, d1.DstPort, d1.Data, d1.Address()) + } + } + } + }(ue, dst) + return nil +} diff --git a/vendor/github.com/txthinking/socks5/server_side.go b/vendor/github.com/txthinking/socks5/server_side.go new file mode 100644 index 00000000..a851d94f --- /dev/null +++ b/vendor/github.com/txthinking/socks5/server_side.go @@ -0,0 +1,298 @@ +package socks5 + +import ( + "errors" + "io" + "log" +) + +var ( + // ErrVersion is version error + ErrVersion = errors.New("Invalid Version") + // ErrUserPassVersion is username/password auth version error + ErrUserPassVersion = errors.New("Invalid Version of Username Password Auth") + // ErrBadRequest is bad request error + ErrBadRequest = errors.New("Bad Request") +) + +// NewNegotiationRequestFrom read negotiation requst packet from client +func NewNegotiationRequestFrom(r io.Reader) (*NegotiationRequest, error) { + // memory strict + bb := make([]byte, 2) + if _, err := io.ReadFull(r, bb); err != nil { + return nil, err + } + if bb[0] != Ver { + return nil, ErrVersion + } + if bb[1] == 0 { + return nil, ErrBadRequest + } + ms := make([]byte, int(bb[1])) + if _, err := io.ReadFull(r, ms); err != nil { + return nil, err + } + if Debug { + log.Printf("Got NegotiationRequest: %#v %#v %#v\n", bb[0], bb[1], ms) + } + return &NegotiationRequest{ + Ver: bb[0], + NMethods: bb[1], + Methods: ms, + }, nil +} + +// NewNegotiationReply return negotiation reply packet can be writed into client +func NewNegotiationReply(method byte) *NegotiationReply { + return &NegotiationReply{ + Ver: Ver, + Method: method, + } +} + +// WriteTo write negotiation reply packet into client +func (r *NegotiationReply) WriteTo(w io.Writer) (int64, error) { + var n int + i, err := w.Write([]byte{r.Ver, r.Method}) + n = n + i + if err != nil { + return int64(n), err + } + if Debug { + log.Printf("Sent NegotiationReply: %#v %#v\n", r.Ver, r.Method) + } + return int64(n), nil +} + +// NewUserPassNegotiationRequestFrom read user password negotiation request packet from client +func NewUserPassNegotiationRequestFrom(r io.Reader) (*UserPassNegotiationRequest, error) { + bb := make([]byte, 2) + if _, err := io.ReadFull(r, bb); err != nil { + return nil, err + } + if bb[0] != UserPassVer { + return nil, ErrUserPassVersion + } + if bb[1] == 0 { + return nil, ErrBadRequest + } + ub := make([]byte, int(bb[1])+1) + if _, err := io.ReadFull(r, ub); err != nil { + return nil, err + } + if ub[int(bb[1])] == 0 { + return nil, ErrBadRequest + } + p := make([]byte, int(ub[int(bb[1])])) + if _, err := io.ReadFull(r, p); err != nil { + return nil, err + } + if Debug { + log.Printf("Got UserPassNegotiationRequest: %#v %#v %#v %#v %#v\n", bb[0], bb[1], ub[:int(bb[1])], ub[int(bb[1])], p) + } + return &UserPassNegotiationRequest{ + Ver: bb[0], + Ulen: bb[1], + Uname: ub[:int(bb[1])], + Plen: ub[int(bb[1])], + Passwd: p, + }, nil +} + +// NewUserPassNegotiationReply return 
negotiation username password reply packet can be writed into client +func NewUserPassNegotiationReply(status byte) *UserPassNegotiationReply { + return &UserPassNegotiationReply{ + Ver: UserPassVer, + Status: status, + } +} + +// WriteTo write negotiation username password reply packet into client +func (r *UserPassNegotiationReply) WriteTo(w io.Writer) (int64, error) { + var n int + i, err := w.Write([]byte{r.Ver, r.Status}) + n = n + i + if err != nil { + return int64(n), err + } + if Debug { + log.Printf("Sent UserPassNegotiationReply: %#v %#v \n", r.Ver, r.Status) + } + return int64(n), nil +} + +// NewRequestFrom read requst packet from client +func NewRequestFrom(r io.Reader) (*Request, error) { + bb := make([]byte, 4) + if _, err := io.ReadFull(r, bb); err != nil { + return nil, err + } + if bb[0] != Ver { + return nil, ErrVersion + } + var addr []byte + if bb[3] == ATYPIPv4 { + addr = make([]byte, 4) + if _, err := io.ReadFull(r, addr); err != nil { + return nil, err + } + } else if bb[3] == ATYPIPv6 { + addr = make([]byte, 16) + if _, err := io.ReadFull(r, addr); err != nil { + return nil, err + } + } else if bb[3] == ATYPDomain { + dal := make([]byte, 1) + if _, err := io.ReadFull(r, dal); err != nil { + return nil, err + } + if dal[0] == 0 { + return nil, ErrBadRequest + } + addr = make([]byte, int(dal[0])) + if _, err := io.ReadFull(r, addr); err != nil { + return nil, err + } + addr = append(dal, addr...) + } else { + return nil, ErrBadRequest + } + port := make([]byte, 2) + if _, err := io.ReadFull(r, port); err != nil { + return nil, err + } + if Debug { + log.Printf("Got Request: %#v %#v %#v %#v %#v %#v\n", bb[0], bb[1], bb[2], bb[3], addr, port) + } + return &Request{ + Ver: bb[0], + Cmd: bb[1], + Rsv: bb[2], + Atyp: bb[3], + DstAddr: addr, + DstPort: port, + }, nil +} + +// NewReply return reply packet can be writed into client, bndaddr should not have domain length +func NewReply(rep byte, atyp byte, bndaddr []byte, bndport []byte) *Reply { + if atyp == ATYPDomain { + bndaddr = append([]byte{byte(len(bndaddr))}, bndaddr...) + } + return &Reply{ + Ver: Ver, + Rep: rep, + Rsv: 0x00, + Atyp: atyp, + BndAddr: bndaddr, + BndPort: bndport, + } +} + +// WriteTo write reply packet into client +func (r *Reply) WriteTo(w io.Writer) (int64, error) { + var n int + i, err := w.Write([]byte{r.Ver, r.Rep, r.Rsv, r.Atyp}) + n = n + i + if err != nil { + return int64(n), err + } + i, err = w.Write(r.BndAddr) + n = n + i + if err != nil { + return int64(n), err + } + i, err = w.Write(r.BndPort) + n = n + i + if err != nil { + return int64(n), err + } + if Debug { + log.Printf("Sent Reply: %#v %#v %#v %#v %#v %#v\n", r.Ver, r.Rep, r.Rsv, r.Atyp, r.BndAddr, r.BndPort) + } + return int64(n), nil +} + +func NewDatagramFromBytes(bb []byte) (*Datagram, error) { + n := len(bb) + minl := 4 + if n < minl { + return nil, ErrBadRequest + } + var addr []byte + if bb[3] == ATYPIPv4 { + minl += 4 + if n < minl { + return nil, ErrBadRequest + } + addr = bb[minl-4 : minl] + } else if bb[3] == ATYPIPv6 { + minl += 16 + if n < minl { + return nil, ErrBadRequest + } + addr = bb[minl-16 : minl] + } else if bb[3] == ATYPDomain { + minl += 1 + if n < minl { + return nil, ErrBadRequest + } + l := bb[4] + if l == 0 { + return nil, ErrBadRequest + } + minl += int(l) + if n < minl { + return nil, ErrBadRequest + } + addr = bb[minl-int(l) : minl] + addr = append([]byte{l}, addr...) 
+ } else { + return nil, ErrBadRequest + } + minl += 2 + if n <= minl { + return nil, ErrBadRequest + } + port := bb[minl-2 : minl] + data := bb[minl:] + d := &Datagram{ + Rsv: bb[0:2], + Frag: bb[2], + Atyp: bb[3], + DstAddr: addr, + DstPort: port, + Data: data, + } + if Debug { + log.Printf("Got Datagram. data: %#v %#v %#v %#v %#v %#v datagram address: %#v\n", d.Rsv, d.Frag, d.Atyp, d.DstAddr, d.DstPort, d.Data, d.Address()) + } + return d, nil +} + +// NewDatagram return datagram packet can be writed into client, dstaddr should not have domain length +func NewDatagram(atyp byte, dstaddr []byte, dstport []byte, data []byte) *Datagram { + if atyp == ATYPDomain { + dstaddr = append([]byte{byte(len(dstaddr))}, dstaddr...) + } + return &Datagram{ + Rsv: []byte{0x00, 0x00}, + Frag: 0x00, + Atyp: atyp, + DstAddr: dstaddr, + DstPort: dstport, + Data: data, + } +} + +// Bytes return []byte +func (d *Datagram) Bytes() []byte { + b := make([]byte, 0) + b = append(b, d.Rsv...) + b = append(b, d.Frag) + b = append(b, d.Atyp) + b = append(b, d.DstAddr...) + b = append(b, d.DstPort...) + b = append(b, d.Data...) + return b +} diff --git a/vendor/github.com/txthinking/socks5/socks5.go b/vendor/github.com/txthinking/socks5/socks5.go new file mode 100644 index 00000000..d77253ba --- /dev/null +++ b/vendor/github.com/txthinking/socks5/socks5.go @@ -0,0 +1,119 @@ +package socks5 + +const ( + // Ver is socks protocol version + Ver byte = 0x05 + + // MethodNone is none method + MethodNone byte = 0x00 + // MethodGSSAPI is gssapi method + MethodGSSAPI byte = 0x01 // MUST support // todo + // MethodUsernamePassword is username/assword auth method + MethodUsernamePassword byte = 0x02 // SHOULD support + // MethodUnsupportAll means unsupport all given methods + MethodUnsupportAll byte = 0xFF + + // UserPassVer is username/password auth protocol version + UserPassVer byte = 0x01 + // UserPassStatusSuccess is success status of username/password auth + UserPassStatusSuccess byte = 0x00 + // UserPassStatusFailure is failure status of username/password auth + UserPassStatusFailure byte = 0x01 // just other than 0x00 + + // CmdConnect is connect command + CmdConnect byte = 0x01 + // CmdBind is bind command + CmdBind byte = 0x02 + // CmdUDP is UDP command + CmdUDP byte = 0x03 + + // ATYPIPv4 is ipv4 address type + ATYPIPv4 byte = 0x01 // 4 octets + // ATYPDomain is domain address type + ATYPDomain byte = 0x03 // The first octet of the address field contains the number of octets of name that follow, there is no terminating NUL octet. 
+ // ATYPIPv6 is ipv6 address type + ATYPIPv6 byte = 0x04 // 16 octets + + // RepSuccess means that success for repling + RepSuccess byte = 0x00 + // RepServerFailure means the server failure + RepServerFailure byte = 0x01 + // RepNotAllowed means the request not allowed + RepNotAllowed byte = 0x02 + // RepNetworkUnreachable means the network unreachable + RepNetworkUnreachable byte = 0x03 + // RepHostUnreachable means the host unreachable + RepHostUnreachable byte = 0x04 + // RepConnectionRefused means the connection refused + RepConnectionRefused byte = 0x05 + // RepTTLExpired means the TTL expired + RepTTLExpired byte = 0x06 + // RepCommandNotSupported means the request command not supported + RepCommandNotSupported byte = 0x07 + // RepAddressNotSupported means the request address not supported + RepAddressNotSupported byte = 0x08 +) + +// NegotiationRequest is the negotiation reqeust packet +type NegotiationRequest struct { + Ver byte + NMethods byte + Methods []byte // 1-255 bytes +} + +// NegotiationReply is the negotiation reply packet +type NegotiationReply struct { + Ver byte + Method byte +} + +// UserPassNegotiationRequest is the negotiation username/password reqeust packet +type UserPassNegotiationRequest struct { + Ver byte + Ulen byte + Uname []byte // 1-255 bytes + Plen byte + Passwd []byte // 1-255 bytes +} + +// UserPassNegotiationReply is the negotiation username/password reply packet +type UserPassNegotiationReply struct { + Ver byte + Status byte +} + +// Request is the request packet +type Request struct { + Ver byte + Cmd byte + Rsv byte // 0x00 + Atyp byte + DstAddr []byte + DstPort []byte // 2 bytes +} + +// Reply is the reply packet +type Reply struct { + Ver byte + Rep byte + Rsv byte // 0x00 + Atyp byte + // CONNECT socks server's address which used to connect to dst addr + // BIND ... + // UDP socks server's address which used to connect to dst addr + BndAddr []byte + // CONNECT socks server's port which used to connect to dst addr + // BIND ... + // UDP socks server's port which used to connect to dst addr + BndPort []byte // 2 bytes +} + +// Datagram is the UDP packet +type Datagram struct { + Rsv []byte // 0x00 0x00 + Frag byte + Atyp byte + DstAddr []byte + DstPort []byte // 2 bytes + Data []byte +} diff --git a/vendor/github.com/txthinking/socks5/udp.go b/vendor/github.com/txthinking/socks5/udp.go new file mode 100644 index 00000000..b19160a0 --- /dev/null +++ b/vendor/github.com/txthinking/socks5/udp.go @@ -0,0 +1,56 @@ +package socks5 + +import ( + "bytes" + "log" + "net" +) + +// UDP remote conn which u want to connect with your dialer. +// Error or OK both replied. +// Addr can be used to associate TCP connection with the coming UDP connection. +func (r *Request) UDP(c *net.TCPConn, serverAddr *net.UDPAddr) (*net.UDPAddr, error) { + var clientAddr *net.UDPAddr + var err error + if bytes.Compare(r.DstPort, []byte{0x00, 0x00}) == 0 { + // If the requested Host/Port is all zeros, the relay should simply use the Host/Port that sent the request. 
+ // https://stackoverflow.com/questions/62283351/how-to-use-socks-5-proxy-with-tidudpclient-properly + clientAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String()) + } else { + clientAddr, err = net.ResolveUDPAddr("udp", r.Address()) + } + if err != nil { + var p *Reply + if r.Atyp == ATYPIPv4 || r.Atyp == ATYPDomain { + p = NewReply(RepHostUnreachable, ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}) + } else { + p = NewReply(RepHostUnreachable, ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00}) + } + if _, err := p.WriteTo(c); err != nil { + return nil, err + } + return nil, err + } + if Debug { + log.Println("Client wants to start UDP talk use", clientAddr.String()) + } + a, addr, port, err := ParseAddress(serverAddr.String()) + if err != nil { + var p *Reply + if r.Atyp == ATYPIPv4 || r.Atyp == ATYPDomain { + p = NewReply(RepHostUnreachable, ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}) + } else { + p = NewReply(RepHostUnreachable, ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00}) + } + if _, err := p.WriteTo(c); err != nil { + return nil, err + } + return nil, err + } + p := NewReply(RepSuccess, a, addr, port) + if _, err := p.WriteTo(c); err != nil { + return nil, err + } + + return clientAddr, nil +} diff --git a/vendor/github.com/txthinking/socks5/util.go b/vendor/github.com/txthinking/socks5/util.go new file mode 100644 index 00000000..cb586757 --- /dev/null +++ b/vendor/github.com/txthinking/socks5/util.go @@ -0,0 +1,135 @@ +package socks5 + +import ( + "bytes" + "encoding/binary" + "errors" + "net" + "strconv" +) + +// ParseAddress format address x.x.x.x:xx to raw address. +// addr contains domain length +func ParseAddress(address string) (a byte, addr []byte, port []byte, err error) { + var h, p string + h, p, err = net.SplitHostPort(address) + if err != nil { + return + } + ip := net.ParseIP(h) + if ip4 := ip.To4(); ip4 != nil { + a = ATYPIPv4 + addr = []byte(ip4) + } else if ip6 := ip.To16(); ip6 != nil { + a = ATYPIPv6 + addr = []byte(ip6) + } else { + a = ATYPDomain + addr = []byte{byte(len(h))} + addr = append(addr, []byte(h)...) 
+ } + i, _ := strconv.Atoi(p) + port = make([]byte, 2) + binary.BigEndian.PutUint16(port, uint16(i)) + return +} + +// bytes to address +// addr contains domain length +func ParseBytesAddress(b []byte) (a byte, addr []byte, port []byte, err error) { + if len(b) < 1 { + err = errors.New("Invalid address") + return + } + a = b[0] + if a == ATYPIPv4 { + if len(b) < 1+4+2 { + err = errors.New("Invalid address") + return + } + addr = b[1 : 1+4] + port = b[1+4 : 1+4+2] + return + } + if a == ATYPIPv6 { + if len(b) < 1+16+2 { + err = errors.New("Invalid address") + return + } + addr = b[1 : 1+16] + port = b[1+16 : 1+16+2] + return + } + if a == ATYPDomain { + if len(b) < 1+1 { + err = errors.New("Invalid address") + return + } + l := int(b[1]) + if len(b) < 1+1+l+2 { + err = errors.New("Invalid address") + return + } + addr = b[1 : 1+1+l] + port = b[1+1+l : 1+1+l+2] + return + } + err = errors.New("Invalid address") + return +} + +// ToAddress format raw address to x.x.x.x:xx +// addr contains domain length +func ToAddress(a byte, addr []byte, port []byte) string { + var h, p string + if a == ATYPIPv4 || a == ATYPIPv6 { + h = net.IP(addr).String() + } + if a == ATYPDomain { + if len(addr) < 1 { + return "" + } + if len(addr) < int(addr[0])+1 { + return "" + } + h = string(addr[1:]) + } + p = strconv.Itoa(int(binary.BigEndian.Uint16(port))) + return net.JoinHostPort(h, p) +} + +// Address return request address like ip:xx +func (r *Request) Address() string { + var s string + if r.Atyp == ATYPDomain { + s = bytes.NewBuffer(r.DstAddr[1:]).String() + } else { + s = net.IP(r.DstAddr).String() + } + p := strconv.Itoa(int(binary.BigEndian.Uint16(r.DstPort))) + return net.JoinHostPort(s, p) +} + +// Address return request address like ip:xx +func (r *Reply) Address() string { + var s string + if r.Atyp == ATYPDomain { + s = bytes.NewBuffer(r.BndAddr[1:]).String() + } else { + s = net.IP(r.BndAddr).String() + } + p := strconv.Itoa(int(binary.BigEndian.Uint16(r.BndPort))) + return net.JoinHostPort(s, p) +} + +// Address return datagram address like ip:xx +func (d *Datagram) Address() string { + var s string + if d.Atyp == ATYPDomain { + s = bytes.NewBuffer(d.DstAddr[1:]).String() + } else { + s = net.IP(d.DstAddr).String() + } + p := strconv.Itoa(int(binary.BigEndian.Uint16(d.DstPort))) + return net.JoinHostPort(s, p) +} diff --git a/vendor/github.com/txthinking/socks5/util_test.go b/vendor/github.com/txthinking/socks5/util_test.go new file mode 100644 index 00000000..9a891e9c --- /dev/null +++ b/vendor/github.com/txthinking/socks5/util_test.go @@ -0,0 +1,9 @@ +package socks5 + +import "testing" + +func TestParseAddress(t *testing.T) { + t.Log(ParseAddress("127.0.0.1:80")) + t.Log(ParseAddress("[::1]:80")) + t.Log(ParseAddress("a.com:80")) +} diff --git a/vendor/github.com/txthinking/x/.github/ISSUE_TEMPLATE b/vendor/github.com/txthinking/x/.github/ISSUE_TEMPLATE new file mode 100644 index 00000000..94e28b59 --- /dev/null +++ b/vendor/github.com/txthinking/x/.github/ISSUE_TEMPLATE @@ -0,0 +1,10 @@ +#### Describe actual behavior + +#### What is your expected behavior + +#### Specifications like the version of the project, operating system, or hardware + +#### Steps to reproduce the problem +0. +1. +2. diff --git a/vendor/github.com/txthinking/x/.github/PULL_REQUEST_TEMPLATE b/vendor/github.com/txthinking/x/.github/PULL_REQUEST_TEMPLATE new file mode 100644 index 00000000..489e7e43 --- /dev/null +++ b/vendor/github.com/txthinking/x/.github/PULL_REQUEST_TEMPLATE @@ -0,0 +1,8 @@ +Fixes # . 
+ +Changes proposed in this pull request: +- +- +- + +@mentions diff --git a/vendor/github.com/txthinking/x/.gitignore b/vendor/github.com/txthinking/x/.gitignore new file mode 100644 index 00000000..e69de29b diff --git a/vendor/github.com/txthinking/x/.travis.yml b/vendor/github.com/txthinking/x/.travis.yml new file mode 100644 index 00000000..c2a609ed --- /dev/null +++ b/vendor/github.com/txthinking/x/.travis.yml @@ -0,0 +1,6 @@ +language: go +sudo: false +go: +install: +script: + - go test -v . diff --git a/vendor/github.com/txthinking/x/LICENSE b/vendor/github.com/txthinking/x/LICENSE new file mode 100644 index 00000000..49755836 --- /dev/null +++ b/vendor/github.com/txthinking/x/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2013-present Cloud https://www.txthinking.com + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
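For orientation, a minimal round-trip sketch of the socks5 address helpers vendored earlier in this diff (`ParseAddress`/`ToAddress` from util.go). This is an illustrative standalone program, not part of the vendored code:

```go
package main

import (
	"fmt"

	"github.com/txthinking/socks5"
)

func main() {
	// For domain names, ParseAddress returns ATYPDomain and a raw address
	// that carries its one-byte length prefix.
	a, addr, port, err := socks5.ParseAddress("a.com:80")
	if err != nil {
		panic(err)
	}
	fmt.Println(a == socks5.ATYPDomain, addr) // true [5 97 46 99 111 109]

	// ToAddress reverses the transform back to "host:port".
	fmt.Println(socks5.ToAddress(a, addr, port)) // a.com:80
}
```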
diff --git a/vendor/github.com/txthinking/x/README.md b/vendor/github.com/txthinking/x/README.md new file mode 100644 index 00000000..783755f8 --- /dev/null +++ b/vendor/github.com/txthinking/x/README.md @@ -0,0 +1,18 @@ +## x +[![GoDoc](https://godoc.org/github.com/txthinking/x?status.svg)](https://godoc.org/github.com/txthinking/x) + +My util library + +### Install + +``` +$ go get github.com/txthinking/x +``` + +## Author + +A project by [txthinking](https://www.txthinking.com) + +## License + +Licensed under The MIT License diff --git a/vendor/github.com/txthinking/x/dial.go b/vendor/github.com/txthinking/x/dial.go new file mode 100644 index 00000000..0142e3d4 --- /dev/null +++ b/vendor/github.com/txthinking/x/dial.go @@ -0,0 +1,30 @@ +package x + +import ( + "net" +) + +// Dialer is a common interface for dialing +type Dialer interface { + Dial(network, addr string) (net.Conn, error) + DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) + DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) +} + +type Dial struct { +} + +// DefaultDial is the default dialer in net package +var DefaultDial = &Dial{} + +func (d *Dial) Dial(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) +} + +func (d *Dial) DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + return net.DialTCP(network, laddr, raddr) +} + +func (d *Dial) DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + return net.DialUDP(network, laddr, raddr) +} diff --git a/vendor/github.com/txthinking/x/geo.go b/vendor/github.com/txthinking/x/geo.go new file mode 100644 index 00000000..fb15750f --- /dev/null +++ b/vendor/github.com/txthinking/x/geo.go @@ -0,0 +1,16 @@ +package x + +import "math" + +// return m +func Distance(lng1, lat1, lng2, lat2 float64) float64 { + radius := 6371000.0 + rad := math.Pi / 180.0 + lng1 = lng1 * rad + lat1 = lat1 * rad + lng2 = lng2 * rad + lat2 = lat2 * rad + theta := lng2 - lng1 + dist := math.Acos(math.Sin(lat1)*math.Sin(lat2) + math.Cos(lat1)*math.Cos(lat2)*math.Cos(theta)) + return dist * radius +} diff --git a/vendor/github.com/txthinking/x/http.go b/vendor/github.com/txthinking/x/http.go new file mode 100644 index 00000000..57cf40d1 --- /dev/null +++ b/vendor/github.com/txthinking/x/http.go @@ -0,0 +1,103 @@ +package x + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "path/filepath" +) + +// MultipartFormDataFromFile generate multipart form data according to RFC 2388. 
+// files is the paths of your files +func MultipartFormDataFromFile(params, files map[string][]string, boundary string) (ior io.Reader, err error) { + var bs []byte + bf := &bytes.Buffer{} + + // prepare common value + var name, value string + var values []string + for name, values = range params { + for _, value = range values { + bf.WriteString(fmt.Sprintf("--%s\r\n", boundary)) + bf.WriteString(fmt.Sprintf("Content-Disposition: form-data; name=\"%s\"\r\n\r\n", name)) + bf.WriteString(fmt.Sprintf("%s\r\n", value)) + } + } + + for name, values = range files { + for _, value = range values { + bf.WriteString(fmt.Sprintf("--%s\r\n", boundary)) + bf.WriteString(fmt.Sprintf("Content-Disposition: form-data; name=\"%s\"; filename=\"%s\"\r\n", name, filepath.Base(value))) + bf.WriteString(fmt.Sprintf("Content-Type: application/octet-stream\r\n\r\n")) + bs, err = ioutil.ReadFile(value) + if err != nil { + return + } + bf.Write(bs) + bf.WriteString("\r\n") + } + } + bf.WriteString(fmt.Sprintf("--%s--\r\n", boundary)) + ior = bf + return +} + +// MultipartFormDataFromReader generate multipart form data according to RFC 2388. +func MultipartFormDataFromReader(params map[string][]string, files map[string][]io.Reader, boundary string) (ior io.Reader, err error) { + var bs []byte + bf := &bytes.Buffer{} + + // prepare common value + var name, value string + var values []string + for name, values = range params { + for _, value = range values { + bf.WriteString(fmt.Sprintf("--%s\r\n", boundary)) + bf.WriteString(fmt.Sprintf("Content-Disposition: form-data; name=\"%s\"\r\n\r\n", name)) + bf.WriteString(fmt.Sprintf("%s\r\n", value)) + } + } + + var rs []io.Reader + var r io.Reader + for name, rs = range files { + for _, r = range rs { + bf.WriteString(fmt.Sprintf("--%s\r\n", boundary)) + bf.WriteString(fmt.Sprintf("Content-Disposition: form-data; name=\"%s\"; filename=\"%s\"\r\n", name, "-")) + bf.WriteString(fmt.Sprintf("Content-Type: application/octet-stream\r\n\r\n")) + bs, err = ioutil.ReadAll(r) + if err != nil { + return + } + bf.Write(bs) + bf.WriteString("\r\n") + } + } + bf.WriteString(fmt.Sprintf("--%s--\r\n", boundary)) + ior = bf + return +} + +func ReadJSON(r *http.Request, o interface{}) error { + d, err := ioutil.ReadAll(r.Body) + if err != nil { + return err + } + if err = json.Unmarshal(d, o); err != nil { + return err + } + return nil +} + +func JSON(w http.ResponseWriter, v interface{}) { + d, err := json.Marshal(v) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + w.Header().Set("content-type", "application/json") + w.Write(d) +} diff --git a/vendor/github.com/txthinking/x/init.go b/vendor/github.com/txthinking/x/init.go new file mode 100644 index 00000000..32e6e2c6 --- /dev/null +++ b/vendor/github.com/txthinking/x/init.go @@ -0,0 +1,10 @@ +package x + +// import _ "net/http/pprof" + +func init() { + // log.SetFlags(log.LstdFlags | log.Lshortfile) + // go func() { + // log.Println(http.ListenAndServe(":6060", nil)) + // }() +} diff --git a/vendor/github.com/txthinking/x/ip.go b/vendor/github.com/txthinking/x/ip.go new file mode 100644 index 00000000..079892fc --- /dev/null +++ b/vendor/github.com/txthinking/x/ip.go @@ -0,0 +1,181 @@ +package x + +import ( + "encoding/binary" + "errors" + "net" + "strconv" + "strings" +) + +// IP2Decimal0 transform ip format like x.x.x.x to decimal. 
+// ref: https://zh.wikipedia.org/wiki/IPv4 +func IP2Decimal0(ip string) (n int64, err error) { + ss := strings.Split(ip, ".") + var b string + var s string + var i int64 + if len(ss) != 4 { + err = errors.New("IP Invalid") + return + } + for _, s = range ss { + i, err = strconv.ParseInt(s, 10, 64) + if err != nil { + return + } + s = strconv.FormatInt(i, 2) + var j int + need := 8 - len(s) + for j = 0; j < need; j++ { + s = "0" + s + } + b += s + } + n, _ = strconv.ParseInt(b, 2, 64) + return +} + +// IP2Decimal1 transform ip format like x.x.x.x to decimal. +func IP2Decimal1(ipstr string) (int64, error) { + ip := net.ParseIP(ipstr) + if ip == nil { + return 0, errors.New("ParseIP error") + } + // ip is 16 bytes, but ipv4 is only in last 4 bytes + d := uint32(ip[12])<<24 | uint32(ip[13])<<16 | uint32(ip[14])<<8 | uint32(ip[15]) + return int64(d), nil +} + +// IP2Decimal transform ip format like x.x.x.x to decimal. +func IP2Decimal(ipstr string) (int64, error) { + ip := net.ParseIP(ipstr) + if ip == nil { + return 0, errors.New("ParseIP error") + } + // ip is 16 bytes, but ipv4 is only in last 4 bytes + d := binary.BigEndian.Uint32(ip[12:16]) + return int64(d), nil +} + +// Decimal2IP0 transform a decimal IP to x.x.x.x format. +// ref: https://zh.wikipedia.org/wiki/IPv4 +func Decimal2IP0(n int64) (ip string, err error) { + ips := make([]string, 4) + var b string + var i int64 + b = strconv.FormatInt(n, 2) + need := 32 - len(b) + var j int + for j = 0; j < need; j++ { + b = "0" + b + } + i, _ = strconv.ParseInt(b[0:8], 2, 64) + ips[0] = strconv.FormatInt(i, 10) + i, _ = strconv.ParseInt(b[8:16], 2, 64) + ips[1] = strconv.FormatInt(i, 10) + i, _ = strconv.ParseInt(b[16:24], 2, 64) + ips[2] = strconv.FormatInt(i, 10) + i, _ = strconv.ParseInt(b[24:32], 2, 64) + ips[3] = strconv.FormatInt(i, 10) + ip = strings.Join(ips, ".") + return +} + +// Decimal2IP1 transform a decimal IP to x.x.x.x format. +func Decimal2IP1(n int64) string { + ui := uint32(n) + ip := "" + for i := 0; i < 4; i++ { + offset := 8 * (3 - i) + tmp := (ui >> uint32(offset)) & 0xff + if ip != "" { + ip += "." + } + ip += strconv.Itoa(int(tmp)) + } + return ip +} + +// Decimal2IP transform a decimal IP to x.x.x.x format. 
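+// e.g. Decimal2IP(3232235786) returns "192.168.1.10" (0xC0A8010A).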
+func Decimal2IP(n int64) string { + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, uint32(n)) + return ip.String() +} + +// CIDRInfo is the struct of CIDR +type CIDRInfo struct { + First string + Last string + Block int64 + Network string + Count int64 +} + +// CIDR return *CIDRInfo from like this x.x.x.x/x +// ref: http://goo.gl/AEUIi8 +func CIDR(cidr string) (c *CIDRInfo, err error) { + c = new(CIDRInfo) + cs := strings.Split(cidr, "/") + if len(cs) != 2 { + err = errors.New("CIDR Invalid") + return + } + var ipd int64 + ipd, err = IP2Decimal(cs[0]) + if err != nil { + return + } + var ipb string + ipb = strconv.FormatInt(ipd, 2) + need := 32 - len(ipb) + var j int + for j = 0; j < need; j++ { + ipb = "0" + ipb + } + + var n int64 + n, err = strconv.ParseInt(cs[1], 10, 64) + if err != nil { + return + } + if n < 0 || n > 32 { + err = errors.New("CIDR Invalid") + return + } + c.Block = n + + var network string + var networkI int64 + for j = 0; j < int(n); j++ { + network += "1" + } + for j = 0; j < 32-int(n); j++ { + network += "0" + } + networkI, _ = strconv.ParseInt(network, 2, 64) + network = Decimal2IP(networkI) + c.Network = network + + first := ipb[0:n] + var firstI int64 + for j = 0; j < 32-int(n); j++ { + first = first + "0" + } + firstI, _ = strconv.ParseInt(first, 2, 64) + first = Decimal2IP(firstI) + c.First = first + + last := ipb[0:n] + var lastI int64 + for j = 0; j < 32-int(n); j++ { + last = last + "1" + } + lastI, _ = strconv.ParseInt(last, 2, 64) + last = Decimal2IP(lastI) + c.Last = last + + c.Count = lastI - firstI + 1 + return +} diff --git a/vendor/github.com/txthinking/x/ip_test.go b/vendor/github.com/txthinking/x/ip_test.go new file mode 100644 index 00000000..af0a329c --- /dev/null +++ b/vendor/github.com/txthinking/x/ip_test.go @@ -0,0 +1,49 @@ +package x + +import ( + "testing" +) + +func TestIP2Decimal(t *testing.T) { + r, err := IP2Decimal0("192.168.1.10") + if err != nil { + t.Fatal(r, err) + } + t.Log(r) + r, err = IP2Decimal1("192.168.1.10") + if err != nil { + t.Fatal(r, err) + } + t.Log(r) + r, err = IP2Decimal("192.168.1.10") + if err != nil { + t.Fatal(r, err) + } + t.Log(r) +} + +func TestDecimal2IP(t *testing.T) { + r, err := Decimal2IP0(123423434) + if err != nil { + t.Fatal(r, err) + } + t.Log(r) + r = Decimal2IP1(123423434) + if err != nil { + t.Fatal(r, err) + } + t.Log(r) + r = Decimal2IP(123423434) + if err != nil { + t.Fatal(r, err) + } + t.Log(r) +} + +func TestCIDR(t *testing.T) { + r, err := CIDR("192.168.1.10/6") + if err != nil { + t.Fatal(r, err) + } + t.Log(r) +} diff --git a/vendor/github.com/txthinking/x/is.go b/vendor/github.com/txthinking/x/is.go new file mode 100644 index 00000000..749f4c89 --- /dev/null +++ b/vendor/github.com/txthinking/x/is.go @@ -0,0 +1,85 @@ +package x + +import ( + "math" + "regexp" + "strconv" + "strings" + "unicode" +) + +// IsEmail determine whether it is email address +func IsEmail(email string) (ok bool, err error) { + p := `^\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*$` + ok, err = regexp.MatchString(p, email) + return +} + +// IsBankCard determine whether it is bankcard number +func IsBankCard(n int64) (ok bool, err error) { + s := strconv.FormatInt(n, 10) + var sum int + var i int + for i = 1; i < len(s); i++ { + var now int + now, _ = strconv.Atoi(string(s[len(s)-1-i])) + if i%2 == 0 { + sum += now + continue + } + var _i int + _i = now * 2 + sum += _i / 10 + sum += _i % 10 + } + var v int + v, _ = strconv.Atoi(string(s[len(s)-1])) + if (sum+v)%10 == 0 { + ok = true + } + return +} + +// 
IsChineseID determine whether it is Chinese ID Card Number +func IsChineseID(s string) (ok bool, err error) { + if len(s) != 18 { + return + } + var sum int + var i int + for i = 1; i < len(s); i++ { + var now int + now, err = strconv.Atoi(string(s[len(s)-1-i])) + if err != nil { + return + } + var w int + w = int(math.Pow(2, float64(i+1-1))) % 11 + sum += now * w + } + v := (12 - (sum % 11)) % 11 + if v == 10 { + if strings.ToLower(string(s[len(s)-1])) != "x" { + return + } + ok = true + return + } + if string(s[len(s)-1]) != strconv.Itoa(v) { + return + } + ok = true + return +} + +// IsChineseWords determine whether it is Chinese words +// Notice: NOT ALL +func IsChineseWords(words string) (ok bool, err error) { + // every rune is chinese + for _, c := range words { + if !unicode.Is(unicode.Scripts["Han"], c) { + return false, nil + } + } + return true, nil +} diff --git a/vendor/github.com/txthinking/x/is_test.go b/vendor/github.com/txthinking/x/is_test.go new file mode 100644 index 00000000..c6a989d5 --- /dev/null +++ b/vendor/github.com/txthinking/x/is_test.go @@ -0,0 +1,74 @@ +package x + +import ( + "testing" +) + +func TestIsEmail(t *testing.T) { + es := []string{ + "error@mail", + "correct@mail.com", + "_hi@mail.com", + "aa_hi@qq.com", + "a$_hi@qq.com", + "!@@@1.com", + "!~@@1.cm", + "a@1.cm", + } + for _, v := range es { + ok, err := IsEmail(v) + if err != nil { + t.Fatal(v, err) + } + t.Log(v, ok) + } +} + +func TestIsBankCard(t *testing.T) { + a := []int64{ + 4512893900582108, + 6228480010323650910, + 6228480010323650919, // error + } + for _, v := range a { + ok, err := IsBankCard(v) + if err != nil { + t.Fatal(v, err) + } + t.Log(v, ok) + } +} + +func TestIsChineseID(t *testing.T) { + a := []string{ + "61052819890402574X", + "411081198804220861", + "411081198804220851", + } + for _, v := range a { + ok, err := IsChineseID(v) + if err != nil { + t.Fatal(v, err) + } + t.Log(v, ok) + } +} + +func TestIsChineseWords(t *testing.T) { + a := []struct { + input string + expected bool + }{ + {"猪八戒", true}, + {"xia往往", false}, + } + for _, v := range a { + ok, err := IsChineseWords(v.input) + if err != nil { + t.Fatal(v, err) + } + if ok != v.expected { + t.Fatal("Chinese word test fail") + } + } +} diff --git a/vendor/github.com/txthinking/x/pool.go b/vendor/github.com/txthinking/x/pool.go new file mode 100644 index 00000000..7364ad7c --- /dev/null +++ b/vendor/github.com/txthinking/x/pool.go @@ -0,0 +1,21 @@ +package x + +import "sync" + +func NewBytesPool(n int) sync.Pool { + return sync.Pool{ + New: func() interface{} { + return make([]byte, n) + }, + } +} + +var BP65507 = NewBytesPool(65507) +var BP2048 = NewBytesPool(2048) +var BP40 = NewBytesPool(40) +var BP32 = NewBytesPool(32) +var BP20 = NewBytesPool(20) +var BP16 = NewBytesPool(16) +var BP12 = NewBytesPool(12) +var BP4 = NewBytesPool(4) +var BP2 = NewBytesPool(2) diff --git a/vendor/github.com/txthinking/x/random.go b/vendor/github.com/txthinking/x/random.go new file mode 100644 index 00000000..38cc1c4a --- /dev/null +++ b/vendor/github.com/txthinking/x/random.go @@ -0,0 +1,21 @@ +package x + +import ( + "math/rand" + "time" +) + +// RandomNumber used to get a random number +func RandomNumber() (i int64) { + i = rand.New(rand.NewSource(time.Now().UnixNano())).Int63() + return +} + +// Random used to get a random number between [min, max) +func Random(min, max int64) int64 { + if max <= min { + return min + } + r := rand.New(rand.NewSource(time.Now().UnixNano())) + return r.Int63n(max-min) + min +} diff --git 
a/vendor/github.com/txthinking/x/random_test.go b/vendor/github.com/txthinking/x/random_test.go new file mode 100644 index 00000000..7a754032 --- /dev/null +++ b/vendor/github.com/txthinking/x/random_test.go @@ -0,0 +1,13 @@ +package x + +import ( + "testing" +) + +func TestRandomNumber(t *testing.T) { + t.Log(RandomNumber()) + t.Log(RandomNumber()) + t.Log(RandomNumber()) + t.Log(Random(1000, 9999)) + t.Log(Random(1000, 9999)) +} diff --git a/vendor/github.com/txthinking/x/test_test.go b/vendor/github.com/txthinking/x/test_test.go new file mode 100644 index 00000000..4f1af9fb --- /dev/null +++ b/vendor/github.com/txthinking/x/test_test.go @@ -0,0 +1,6 @@ +package x + +import "testing" + +func TestTest(t *testing.T) { +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/.gitignore b/vendor/github.com/xtaci/kcp-go/v5/.gitignore new file mode 100644 index 00000000..2f4178cc --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/.gitignore @@ -0,0 +1,25 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test +/vendor/ + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/vendor/github.com/xtaci/kcp-go/v5/.travis.yml b/vendor/github.com/xtaci/kcp-go/v5/.travis.yml new file mode 100644 index 00000000..6754ef67 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/.travis.yml @@ -0,0 +1,24 @@ +arch: + - amd64 + - ppc64le +language: go + +go: + - 1.11.x + - 1.12.x + - 1.13.x + +env: + - GO111MODULE=on + +before_install: + - go get -t -v ./... + +install: + - go get github.com/xtaci/kcp-go + +script: + - go test -coverprofile=coverage.txt -covermode=atomic -bench . -timeout 10m + +after_success: + - bash <(curl -s https://codecov.io/bash) diff --git a/vendor/github.com/xtaci/kcp-go/v5/LICENSE b/vendor/github.com/xtaci/kcp-go/v5/LICENSE new file mode 100644 index 00000000..8294d134 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/LICENSE @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2015 Daniel Fu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
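
Before moving on to the kcp-go files, a quick orientation on the `txthinking/x` helpers vendored above. This is a minimal usage sketch, not part of the diff itself, assuming only the exported API visible in `ip.go`, `random.go` and `pool.go`:

```go
package main

import (
	"fmt"

	"github.com/txthinking/x"
)

func main() {
	// CIDR parses "a.b.c.d/n" and reports the block's bounds; note that
	// the Network field holds the dotted netmask, not the network address.
	info, err := x.CIDR("192.168.1.0/24")
	if err != nil {
		panic(err)
	}
	fmt.Println(info.First, info.Last, info.Network, info.Count)
	// 192.168.1.0 192.168.1.255 255.255.255.0 256

	// Random returns an int64 in [min, max).
	fmt.Println(x.Random(1000, 9999))

	// The BPxxxx vars are sync.Pools handing out fixed-size byte slices;
	// callers are expected to return them after use.
	buf := x.BP2048.Get().([]byte)
	defer x.BP2048.Put(buf)
	fmt.Println(len(buf)) // 2048
}
```
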
+ diff --git a/vendor/github.com/xtaci/kcp-go/v5/README.md b/vendor/github.com/xtaci/kcp-go/v5/README.md new file mode 100644 index 00000000..f68406d9 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/README.md @@ -0,0 +1,279 @@ +kcp-go + + +[![GoDoc][1]][2] [![Powered][9]][10] [![MIT licensed][11]][12] [![Build Status][3]][4] [![Go Report Card][5]][6] [![Coverage Status][7]][8] [![Sourcegraph][13]][14] + +[1]: https://godoc.org/github.com/xtaci/kcp-go?status.svg +[2]: https://pkg.go.dev/github.com/xtaci/kcp-go +[3]: https://travis-ci.org/xtaci/kcp-go.svg?branch=master +[4]: https://travis-ci.org/xtaci/kcp-go +[5]: https://goreportcard.com/badge/github.com/xtaci/kcp-go +[6]: https://goreportcard.com/report/github.com/xtaci/kcp-go +[7]: https://codecov.io/gh/xtaci/kcp-go/branch/master/graph/badge.svg +[8]: https://codecov.io/gh/xtaci/kcp-go +[9]: https://img.shields.io/badge/KCP-Powered-blue.svg +[10]: https://github.com/skywind3000/kcp +[11]: https://img.shields.io/badge/license-MIT-blue.svg +[12]: LICENSE +[13]: https://sourcegraph.com/github.com/xtaci/kcp-go/-/badge.svg +[14]: https://sourcegraph.com/github.com/xtaci/kcp-go?badge + +## Introduction + +**kcp-go** is a **Production-Grade Reliable-UDP** library for [golang](https://golang.org/). + +This library intends to provide **smooth, resilient, ordered, error-checked and anonymous** delivery of streams over **UDP** packets. It has been battle-tested with the open-source project [kcptun](https://github.com/xtaci/kcptun). Millions of devices (from low-end MIPS routers to high-end servers) have deployed **kcp-go**-powered programs in a variety of forms like **online games, live broadcasting, file synchronization and network acceleration**. + +[Latest Release](https://github.com/xtaci/kcp-go/releases) + +## Features + +1. Designed for **Latency-sensitive** scenarios. +1. **Cache friendly** and **Memory optimized** design, offering an extremely **High Performance** core. +1. Handles **>5K concurrent connections** on a single commodity server. +1. Compatible with [net.Conn](https://golang.org/pkg/net/#Conn) and [net.Listener](https://golang.org/pkg/net/#Listener), a drop-in replacement for [net.TCPConn](https://golang.org/pkg/net/#TCPConn); see the sketch after this list. +1. [FEC(Forward Error Correction)](https://en.wikipedia.org/wiki/Forward_error_correction) Support with [Reed-Solomon Codes](https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction) +1. Packet level encryption support with [AES](https://en.wikipedia.org/wiki/Advanced_Encryption_Standard), [TEA](https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm), [3DES](https://en.wikipedia.org/wiki/Triple_DES), [Blowfish](https://en.wikipedia.org/wiki/Blowfish_(cipher)), [Cast5](https://en.wikipedia.org/wiki/CAST-128), [Salsa20](https://en.wikipedia.org/wiki/Salsa20), etc. in [CFB](https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Cipher_Feedback_.28CFB.29) mode, which generates completely anonymous packets. +1. Only **a fixed number of goroutines** will be created for the entire server application; the cost of **context switching** between goroutines has been taken into consideration. +1. Compatible with [skywind3000's](https://github.com/skywind3000) C version, with various improvements. +1. Platform-dependent optimizations: [sendmmsg](http://man7.org/linux/man-pages/man2/sendmmsg.2.html) and [recvmmsg](http://man7.org/linux/man-pages/man2/recvmmsg.2.html) are exploited on Linux.
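
To make the drop-in `net.Conn` claim above concrete, here is a minimal client sketch; it mirrors the `examples/echo.go` added later in this diff, and the `10, 3` arguments are the FEC data/parity shard counts used there:

```go
package main

import (
	"crypto/sha1"
	"log"

	"github.com/xtaci/kcp-go/v5"
	"golang.org/x/crypto/pbkdf2"
)

func main() {
	// Derive a 32-byte key and wrap it in a BlockCrypt for packet encryption.
	key := pbkdf2.Key([]byte("demo pass"), []byte("demo salt"), 1024, 32, sha1.New)
	block, _ := kcp.NewAESBlockCrypt(key)

	// DialWithOptions returns a *kcp.UDPSession, which satisfies net.Conn;
	// 10 data shards + 3 parity shards enable Reed-Solomon FEC.
	sess, err := kcp.DialWithOptions("127.0.0.1:12345", block, 10, 3)
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()

	// From here on, sess is used exactly like any other net.Conn.
	if _, err := sess.Write([]byte("ping")); err != nil {
		log.Fatal(err)
	}
}
```
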
+ +## Documentation + +For complete documentation, see the associated [Godoc](https://godoc.org/github.com/xtaci/kcp-go). + +## Specification + +Frame Format + +``` +NONCE: + 16bytes cryptographically secure random number, nonce changes for every packet. + +CRC32: + CRC-32 checksum of data using the IEEE polynomial + +FEC TYPE: + typeData = 0xF1 + typeParity = 0xF2 + +FEC SEQID: + monotonically increasing in range: [0, (0xffffffff/shardSize) * shardSize - 1] + +SIZE: + The size of KCP frame plus 2 +``` + +``` ++-----------------+ +| SESSION | ++-----------------+ +| KCP(ARQ) | ++-----------------+ +| FEC(OPTIONAL) | ++-----------------+ +| CRYPTO(OPTIONAL)| ++-----------------+ +| UDP(PACKET) | ++-----------------+ +| IP | ++-----------------+ +| LINK | ++-----------------+ +| PHY | ++-----------------+ +(LAYER MODEL OF KCP-GO) +``` + + +## Examples + +1. [simple examples](https://github.com/xtaci/kcp-go/tree/master/examples) +2. [kcptun client](https://github.com/xtaci/kcptun/blob/master/client/main.go) +3. [kcptun server](https://github.com/xtaci/kcptun/blob/master/server/main.go) + +## Benchmark +``` +=== +Model Name: MacBook Pro +Model Identifier: MacBookPro14,1 +Processor Name: Intel Core i5 +Processor Speed: 3.1 GHz +Number of Processors: 1 +Total Number of Cores: 2 +L2 Cache (per Core): 256 KB +L3 Cache: 4 MB +Memory: 8 GB +=== + +$ go test -v -run=^$ -bench . +beginning tests, encryption:salsa20, fec:10/3 +goos: darwin +goarch: amd64 +pkg: github.com/xtaci/kcp-go +BenchmarkSM4-4 50000 32180 ns/op 93.23 MB/s 0 B/op 0 allocs/op +BenchmarkAES128-4 500000 3285 ns/op 913.21 MB/s 0 B/op 0 allocs/op +BenchmarkAES192-4 300000 3623 ns/op 827.85 MB/s 0 B/op 0 allocs/op +BenchmarkAES256-4 300000 3874 ns/op 774.20 MB/s 0 B/op 0 allocs/op +BenchmarkTEA-4 100000 15384 ns/op 195.00 MB/s 0 B/op 0 allocs/op +BenchmarkXOR-4 20000000 89.9 ns/op 33372.00 MB/s 0 B/op 0 allocs/op +BenchmarkBlowfish-4 50000 26927 ns/op 111.41 MB/s 0 B/op 0 allocs/op +BenchmarkNone-4 30000000 45.7 ns/op 65597.94 MB/s 0 B/op 0 allocs/op +BenchmarkCast5-4 50000 34258 ns/op 87.57 MB/s 0 B/op 0 allocs/op +Benchmark3DES-4 10000 117149 ns/op 25.61 MB/s 0 B/op 0 allocs/op +BenchmarkTwofish-4 50000 33538 ns/op 89.45 MB/s 0 B/op 0 allocs/op +BenchmarkXTEA-4 30000 45666 ns/op 65.69 MB/s 0 B/op 0 allocs/op +BenchmarkSalsa20-4 500000 3308 ns/op 906.76 MB/s 0 B/op 0 allocs/op +BenchmarkCRC32-4 20000000 65.2 ns/op 15712.43 MB/s +BenchmarkCsprngSystem-4 1000000 1150 ns/op 13.91 MB/s +BenchmarkCsprngMD5-4 10000000 145 ns/op 110.26 MB/s +BenchmarkCsprngSHA1-4 10000000 158 ns/op 126.54 MB/s +BenchmarkCsprngNonceMD5-4 10000000 153 ns/op 104.22 MB/s +BenchmarkCsprngNonceAES128-4 100000000 19.1 ns/op 837.81 MB/s +BenchmarkFECDecode-4 1000000 1119 ns/op 1339.61 MB/s 1606 B/op 2 allocs/op +BenchmarkFECEncode-4 2000000 832 ns/op 1801.83 MB/s 17 B/op 0 allocs/op +BenchmarkFlush-4 5000000 272 ns/op 0 B/op 0 allocs/op +BenchmarkEchoSpeed4K-4 5000 259617 ns/op 15.78 MB/s 5451 B/op 149 allocs/op +BenchmarkEchoSpeed64K-4 1000 1706084 ns/op 38.41 MB/s 56002 B/op 1604 allocs/op +BenchmarkEchoSpeed512K-4 100 14345505 ns/op 36.55 MB/s 482597 B/op 13045 allocs/op +BenchmarkEchoSpeed1M-4 30 34859104 ns/op 30.08 MB/s 1143773 B/op 27186 allocs/op +BenchmarkSinkSpeed4K-4 50000 31369 ns/op 130.57 MB/s 1566 B/op 30 allocs/op +BenchmarkSinkSpeed64K-4 5000 329065 ns/op 199.16 MB/s 21529 B/op 453 allocs/op +BenchmarkSinkSpeed256K-4 500 2373354 ns/op 220.91 MB/s 166332 B/op 3554 allocs/op +BenchmarkSinkSpeed1M-4 300 5117927 ns/op 204.88 MB/s 310378 B/op 6988 
allocs/op +PASS +ok github.com/xtaci/kcp-go 50.349s +``` + +``` +=== Raspberry Pi 4 === + +➜ kcp-go git:(master) cat /proc/cpuinfo +processor : 0 +model name : ARMv7 Processor rev 3 (v7l) +BogoMIPS : 108.00 +Features : half thumb fastmult vfp edsp neon vfpv3 tls vfpv4 idiva idivt vfpd32 lpae evtstrm crc32 +CPU implementer : 0x41 +CPU architecture: 7 +CPU variant : 0x0 +CPU part : 0xd08 +CPU revision : 3 + +➜ kcp-go git:(master) go test -run=^$ -bench . +2020/01/05 19:25:13 beginning tests, encryption:salsa20, fec:10/3 +goos: linux +goarch: arm +pkg: github.com/xtaci/kcp-go/v5 +BenchmarkSM4-4 20000 86475 ns/op 34.69 MB/s 0 B/op 0 allocs/op +BenchmarkAES128-4 20000 62254 ns/op 48.19 MB/s 0 B/op 0 allocs/op +BenchmarkAES192-4 20000 71802 ns/op 41.78 MB/s 0 B/op 0 allocs/op +BenchmarkAES256-4 20000 80570 ns/op 37.23 MB/s 0 B/op 0 allocs/op +BenchmarkTEA-4 50000 37343 ns/op 80.34 MB/s 0 B/op 0 allocs/op +BenchmarkXOR-4 100000 22266 ns/op 134.73 MB/s 0 B/op 0 allocs/op +BenchmarkBlowfish-4 20000 66123 ns/op 45.37 MB/s 0 B/op 0 allocs/op +BenchmarkNone-4 3000000 518 ns/op 5786.77 MB/s 0 B/op 0 allocs/op +BenchmarkCast5-4 20000 76705 ns/op 39.11 MB/s 0 B/op 0 allocs/op +Benchmark3DES-4 5000 418868 ns/op 7.16 MB/s 0 B/op 0 allocs/op +BenchmarkTwofish-4 5000 326896 ns/op 9.18 MB/s 0 B/op 0 allocs/op +BenchmarkXTEA-4 10000 114418 ns/op 26.22 MB/s 0 B/op 0 allocs/op +BenchmarkSalsa20-4 50000 36736 ns/op 81.66 MB/s 0 B/op 0 allocs/op +BenchmarkCRC32-4 1000000 1735 ns/op 589.98 MB/s +BenchmarkCsprngSystem-4 1000000 2179 ns/op 7.34 MB/s +BenchmarkCsprngMD5-4 2000000 811 ns/op 19.71 MB/s +BenchmarkCsprngSHA1-4 2000000 862 ns/op 23.19 MB/s +BenchmarkCsprngNonceMD5-4 2000000 878 ns/op 18.22 MB/s +BenchmarkCsprngNonceAES128-4 5000000 326 ns/op 48.97 MB/s +BenchmarkFECDecode-4 200000 9081 ns/op 165.16 MB/s 140 B/op 1 allocs/op +BenchmarkFECEncode-4 100000 12039 ns/op 124.59 MB/s 11 B/op 0 allocs/op +BenchmarkFlush-4 100000 21704 ns/op 0 B/op 0 allocs/op +BenchmarkEchoSpeed4K-4 2000 981182 ns/op 4.17 MB/s 12384 B/op 424 allocs/op +BenchmarkEchoSpeed64K-4 100 10503324 ns/op 6.24 MB/s 123616 B/op 3779 allocs/op +BenchmarkEchoSpeed512K-4 20 138633802 ns/op 3.78 MB/s 1606584 B/op 29233 allocs/op +BenchmarkEchoSpeed1M-4 5 372903568 ns/op 2.81 MB/s 4080504 B/op 63600 allocs/op +BenchmarkSinkSpeed4K-4 10000 121239 ns/op 33.78 MB/s 4647 B/op 104 allocs/op +BenchmarkSinkSpeed64K-4 1000 1587906 ns/op 41.27 MB/s 50914 B/op 1115 allocs/op +BenchmarkSinkSpeed256K-4 100 16277830 ns/op 32.21 MB/s 453027 B/op 9296 allocs/op +BenchmarkSinkSpeed1M-4 100 31040703 ns/op 33.78 MB/s 898097 B/op 18932 allocs/op +PASS +ok github.com/xtaci/kcp-go/v5 64.151s +``` + + +## Typical Flame Graph +![Flame Graph in kcptun](flame.png) + +## Key Design Considerations + +1. slice vs. container/list + +`kcp.flush()` loops through the send queue for retransmission checking every 20ms (the update interval). + +I've written a benchmark comparing a sequential loop through a *slice* and a *container/list* here: + +https://github.com/xtaci/notes/blob/master/golang/benchmark2/cachemiss_test.go + +``` +BenchmarkLoopSlice-4 2000000000 0.39 ns/op +BenchmarkLoopList-4 100000000 54.6 ns/op +``` + +The list structure introduces **heavy cache misses** compared to a slice, which has better **locality**: with 5000 connections, a window size of 32 and a 20ms interval, each `kcp.flush()` costs 6us/0.03% (CPU) using a slice, versus 8.7ms/43.5% (CPU) using a list. + +2. Timing accuracy vs. 
syscall clock_gettime + +Timing is **critical** to the **RTT estimator**: inaccurate timing leads to false retransmissions in KCP, but calling `time.Now()` costs 42 cycles (10.5ns on a 4GHz CPU, 15.6ns on my 2.7GHz MacBook Pro). + +The benchmark for time.Now() lives here: + +https://github.com/xtaci/notes/blob/master/golang/benchmark2/syscall_test.go + +``` +BenchmarkNow-4 100000000 15.6 ns/op +``` + +In kcp-go, after each `kcp.output()` call the current clock time is updated upon return, and for a single `kcp.flush()` operation the current time is queried from the system only once. Most of the time, 5000 connections cost 5000 * 15.6ns = 78us (a fixed cost while no packet needs to be sent); for 10MB/s data transfer with a 1400-byte MTU, `kcp.output()` is called around 7500 times and spends 117us in `time.Now()` **every second**. + +3. Memory management + +Primary memory allocations are done from a global buffer pool, xmitBuf. In kcp-go, when we need to allocate some bytes, we get them from that pool, and a fixed-capacity 1500-byte slice (mtuLimit) is returned. The rx queue, tx queue and FEC queue all receive bytes from there, and they return the bytes to the pool after use to prevent unnecessary zeroing of bytes. The pool mechanism maintains a high watermark for slice objects; these in-flight objects from the pool survive the periodic garbage collection, while the pool keeps the ability to return memory to the runtime when idle. + +4. Information security + +kcp-go ships with builtin packet encryption powered by various block encryption algorithms and works in [Cipher Feedback Mode](https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Cipher_Feedback_(CFB)): for each packet to be sent, the encryption process starts by encrypting a [nonce](https://en.wikipedia.org/wiki/Cryptographic_nonce) drawn from the [system entropy](https://en.wikipedia.org/wiki//dev/random), so encrypting the same plaintext never leads to the same ciphertext. + +The contents of the packets are completely anonymous with encryption, including the headers (FEC, KCP), checksums and contents. Note that, no matter which encryption method you choose in your upper layer, if you disable encryption here the transmission will be insecure, since the header is ***PLAINTEXT*** to everyone and would be susceptible to header tampering, such as jamming the *sliding window size*, *round-trip time*, *FEC property* and *checksums*. ```AES-128``` is suggested for minimal encryption, since modern CPUs ship with [AES-NI](https://en.wikipedia.org/wiki/AES_instruction_set) instructions and it performs even better than `salsa20` (check the table above). + +Other possible attacks on kcp-go include: a) [traffic analysis](https://en.wikipedia.org/wiki/Traffic_analysis): dataflow to specific websites may show patterns while interchanging data, but this type of eavesdropping has been mitigated by adapting [smux](https://github.com/xtaci/smux) to mix data streams so as to introduce noise. A perfect solution to this has not appeared yet; theoretically, shuffling/mixing messages on a larger-scale network may mitigate this problem. 
b) [replay attack](https://en.wikipedia.org/wiki/Replay_attack): since asymmetrical encryption has not been introduced into kcp-go, capturing packets and replaying them on a different machine is possible (notice: hijacking the session and decrypting the contents is still *impossible*). Upper layers should therefore contain an asymmetrical encryption system to guarantee the authenticity of each message (to process each message exactly once), such as HTTPS/OpenSSL/LibreSSL; only signing the requests with private keys can eliminate this type of attack. + +## Connection Termination + +Control messages like **SYN/FIN/RST** in TCP **are not defined** in KCP; you need some **keepalive/heartbeat mechanism** at the application level. A real-world example is to use some **multiplexing** protocol over the session, such as [smux](https://github.com/xtaci/smux) (with an embedded keepalive mechanism); see [kcptun](https://github.com/xtaci/kcptun) for an example. + +## FAQ + +Q: I'm handling >5K connections on my server, and the CPU utilization is very high. + +A: A standalone `agent` or `gate` server for running kcp-go is suggested, not only for CPU utilization but also for the **precision** of RTT measurements (timing), which indirectly affects retransmission. Increasing the update `interval` with `SetNoDelay`, as in `conn.SetNoDelay(1, 40, 1, 1)`, will dramatically reduce system load, but lowers performance. + +Q: When should I enable FEC? + +A: Forward error correction is critical to long-distance transmission, because a packet loss leads to a huge penalty in time. And in the complicated packet-routing networks of the modern world, round-trip-time-based loss checking is not always efficient: the large deviation of RTT samples over a long path usually leads to a larger RTO value in a typical RTT estimator, which, in other words, slows down the transmission. + +Q: Should I enable encryption? + +A: Yes, for the safety of the protocol, even if the upper layer is already encrypted. + +## Who is using this? + +1. https://github.com/xtaci/kcptun -- A Secure Tunnel Based On KCP over UDP. +2. https://github.com/getlantern/lantern -- Lantern delivers fast access to the open Internet. +3. https://github.com/smallnest/rpcx -- An RPC service framework based on net/rpc, like Alibaba Dubbo and Weibo Motan. +4. https://github.com/gonet2/agent -- A gateway for games with stream multiplexing. +5. https://github.com/syncthing/syncthing -- Open Source Continuous File Synchronization. + +## Links + +1. https://github.com/xtaci/smux/ -- A Stream Multiplexing Library for golang with least memory +1. https://github.com/xtaci/libkcp -- FEC enhanced KCP session library for iOS/Android in C++ +1. https://github.com/skywind3000/kcp -- A Fast and Reliable ARQ Protocol +1. 
https://github.com/klauspost/reedsolomon -- Reed-Solomon Erasure Coding in Go diff --git a/vendor/github.com/xtaci/kcp-go/v5/autotune.go b/vendor/github.com/xtaci/kcp-go/v5/autotune.go new file mode 100644 index 00000000..1f85be33 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/autotune.go @@ -0,0 +1,64 @@ +package kcp + +const maxAutoTuneSamples = 258 + +// pulse represents a 0/1 signal with time sequence +type pulse struct { + bit bool // 0 or 1 + seq uint32 // sequence of the signal +} + +// autoTune object +type autoTune struct { + pulses [maxAutoTuneSamples]pulse +} + +// Sample adds a signal sample to the pulse buffer +func (tune *autoTune) Sample(bit bool, seq uint32) { + tune.pulses[seq%maxAutoTuneSamples] = pulse{bit, seq} +} + +// Find a period for a given signal +// returns -1 if not found +// +// --- ------ +// | | +// |______________| +// Period +// Falling Edge Rising Edge +func (tune *autoTune) FindPeriod(bit bool) int { + // last pulse and initial index setup + lastPulse := tune.pulses[0] + idx := 1 + + // left edge + var leftEdge int + for ; idx < len(tune.pulses); idx++ { + if lastPulse.bit != bit && tune.pulses[idx].bit == bit { // edge found + if lastPulse.seq+1 == tune.pulses[idx].seq { // ensure edge continuity + leftEdge = idx + break + } + } + lastPulse = tune.pulses[idx] + } + + // right edge + var rightEdge int + lastPulse = tune.pulses[leftEdge] + idx = leftEdge + 1 + + for ; idx < len(tune.pulses); idx++ { + if lastPulse.seq+1 == tune.pulses[idx].seq { // ensure pulses in this level monotonic + if lastPulse.bit == bit && tune.pulses[idx].bit != bit { // edge found + rightEdge = idx + break + } + } else { + return -1 + } + lastPulse = tune.pulses[idx] + } + + return rightEdge - leftEdge +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/autotune_test.go b/vendor/github.com/xtaci/kcp-go/v5/autotune_test.go new file mode 100644 index 00000000..3dc1ecc6 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/autotune_test.go @@ -0,0 +1,47 @@ +package kcp + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAutoTune(t *testing.T) { + signals := []uint32{0, 0, 0, 0, 0, 0} + + tune := autoTune{} + for i := 0; i < len(signals); i++ { + if signals[i] == 0 { + tune.Sample(false, uint32(i)) + } else { + tune.Sample(true, uint32(i)) + } + } + + assert.Equal(t, -1, tune.FindPeriod(false)) + assert.Equal(t, -1, tune.FindPeriod(true)) + + signals = []uint32{1, 0, 1, 0, 0, 1} + tune = autoTune{} + for i := 0; i < len(signals); i++ { + if signals[i] == 0 { + tune.Sample(false, uint32(i)) + } else { + tune.Sample(true, uint32(i)) + } + } + assert.Equal(t, 1, tune.FindPeriod(false)) + assert.Equal(t, 1, tune.FindPeriod(true)) + + signals = []uint32{1, 0, 0, 0, 0, 1} + tune = autoTune{} + for i := 0; i < len(signals); i++ { + if signals[i] == 0 { + tune.Sample(false, uint32(i)) + } else { + tune.Sample(true, uint32(i)) + } + } + assert.Equal(t, -1, tune.FindPeriod(true)) + assert.Equal(t, 4, tune.FindPeriod(false)) +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/batchconn.go b/vendor/github.com/xtaci/kcp-go/v5/batchconn.go new file mode 100644 index 00000000..6c307010 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/batchconn.go @@ -0,0 +1,12 @@ +package kcp + +import "golang.org/x/net/ipv4" + +const ( + batchSize = 16 +) + +type batchConn interface { + WriteBatch(ms []ipv4.Message, flags int) (int, error) + ReadBatch(ms []ipv4.Message, flags int) (int, error) +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/crypt.go 
b/vendor/github.com/xtaci/kcp-go/v5/crypt.go new file mode 100644 index 00000000..d8828522 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/crypt.go @@ -0,0 +1,618 @@ +package kcp + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/sha1" + "unsafe" + + xor "github.com/templexxx/xorsimd" + "github.com/tjfoc/gmsm/sm4" + + "golang.org/x/crypto/blowfish" + "golang.org/x/crypto/cast5" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/crypto/salsa20" + "golang.org/x/crypto/tea" + "golang.org/x/crypto/twofish" + "golang.org/x/crypto/xtea" +) + +var ( + initialVector = []byte{167, 115, 79, 156, 18, 172, 27, 1, 164, 21, 242, 193, 252, 120, 230, 107} + saltxor = `sH3CIVoF#rWLtJo6` +) + +// BlockCrypt defines encryption/decryption methods for a given byte slice. +// Notes on implementing: the data to be encrypted contains a builtin +// nonce at the first 16 bytes +type BlockCrypt interface { + // Encrypt encrypts the whole block in src into dst. + // Dst and src may point at the same memory. + Encrypt(dst, src []byte) + + // Decrypt decrypts the whole block in src into dst. + // Dst and src may point at the same memory. + Decrypt(dst, src []byte) +} + +type salsa20BlockCrypt struct { + key [32]byte +} + +// NewSalsa20BlockCrypt https://en.wikipedia.org/wiki/Salsa20 +func NewSalsa20BlockCrypt(key []byte) (BlockCrypt, error) { + c := new(salsa20BlockCrypt) + copy(c.key[:], key) + return c, nil +} + +func (c *salsa20BlockCrypt) Encrypt(dst, src []byte) { + salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key) + copy(dst[:8], src[:8]) +} +func (c *salsa20BlockCrypt) Decrypt(dst, src []byte) { + salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key) + copy(dst[:8], src[:8]) +} + +type sm4BlockCrypt struct { + encbuf [sm4.BlockSize]byte // 64bit alignment enc/dec buffer + decbuf [2 * sm4.BlockSize]byte + block cipher.Block +} + +// NewSM4BlockCrypt https://github.com/tjfoc/gmsm/tree/master/sm4 +func NewSM4BlockCrypt(key []byte) (BlockCrypt, error) { + c := new(sm4BlockCrypt) + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *sm4BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *sm4BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type twofishBlockCrypt struct { + encbuf [twofish.BlockSize]byte + decbuf [2 * twofish.BlockSize]byte + block cipher.Block +} + +// NewTwofishBlockCrypt https://en.wikipedia.org/wiki/Twofish +func NewTwofishBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(twofishBlockCrypt) + block, err := twofish.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *twofishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *twofishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type tripleDESBlockCrypt struct { + encbuf [des.BlockSize]byte + decbuf [2 * des.BlockSize]byte + block cipher.Block +} + +// NewTripleDESBlockCrypt https://en.wikipedia.org/wiki/Triple_DES +func NewTripleDESBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(tripleDESBlockCrypt) + block, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *tripleDESBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *tripleDESBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type cast5BlockCrypt 
struct { + encbuf [cast5.BlockSize]byte + decbuf [2 * cast5.BlockSize]byte + block cipher.Block +} + +// NewCast5BlockCrypt https://en.wikipedia.org/wiki/CAST-128 +func NewCast5BlockCrypt(key []byte) (BlockCrypt, error) { + c := new(cast5BlockCrypt) + block, err := cast5.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *cast5BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *cast5BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type blowfishBlockCrypt struct { + encbuf [blowfish.BlockSize]byte + decbuf [2 * blowfish.BlockSize]byte + block cipher.Block +} + +// NewBlowfishBlockCrypt https://en.wikipedia.org/wiki/Blowfish_(cipher) +func NewBlowfishBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(blowfishBlockCrypt) + block, err := blowfish.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *blowfishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *blowfishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type aesBlockCrypt struct { + encbuf [aes.BlockSize]byte + decbuf [2 * aes.BlockSize]byte + block cipher.Block +} + +// NewAESBlockCrypt https://en.wikipedia.org/wiki/Advanced_Encryption_Standard +func NewAESBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(aesBlockCrypt) + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *aesBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *aesBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type teaBlockCrypt struct { + encbuf [tea.BlockSize]byte + decbuf [2 * tea.BlockSize]byte + block cipher.Block +} + +// NewTEABlockCrypt https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm +func NewTEABlockCrypt(key []byte) (BlockCrypt, error) { + c := new(teaBlockCrypt) + block, err := tea.NewCipherWithRounds(key, 16) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *teaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *teaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type xteaBlockCrypt struct { + encbuf [xtea.BlockSize]byte + decbuf [2 * xtea.BlockSize]byte + block cipher.Block +} + +// NewXTEABlockCrypt https://en.wikipedia.org/wiki/XTEA +func NewXTEABlockCrypt(key []byte) (BlockCrypt, error) { + c := new(xteaBlockCrypt) + block, err := xtea.NewCipher(key) + if err != nil { + return nil, err + } + c.block = block + return c, nil +} + +func (c *xteaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } +func (c *xteaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } + +type simpleXORBlockCrypt struct { + xortbl []byte +} + +// NewSimpleXORBlockCrypt simple xor with key expanding +func NewSimpleXORBlockCrypt(key []byte) (BlockCrypt, error) { + c := new(simpleXORBlockCrypt) + c.xortbl = pbkdf2.Key(key, []byte(saltxor), 32, mtuLimit, sha1.New) + return c, nil +} + +func (c *simpleXORBlockCrypt) Encrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) } +func (c *simpleXORBlockCrypt) Decrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) } + +type noneBlockCrypt struct{} + +// NewNoneBlockCrypt does nothing but copying +func NewNoneBlockCrypt(key []byte) (BlockCrypt, error) { + return 
new(noneBlockCrypt), nil +} + +func (c *noneBlockCrypt) Encrypt(dst, src []byte) { copy(dst, src) } +func (c *noneBlockCrypt) Decrypt(dst, src []byte) { copy(dst, src) } + +// packet encryption with local CFB mode +func encrypt(block cipher.Block, dst, src, buf []byte) { + switch block.BlockSize() { + case 8: + encrypt8(block, dst, src, buf) + case 16: + encrypt16(block, dst, src, buf) + default: + panic("unsupported cipher block size") + } +} + +// optimized encryption for the ciphers which works in 8-bytes +func encrypt8(block cipher.Block, dst, src, buf []byte) { + tbl := buf[:8] + block.Encrypt(tbl, initialVector) + n := len(src) / 8 + base := 0 + repeat := n / 8 + left := n % 8 + ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0])) + + for i := 0; i < repeat; i++ { + s := src[base:][0:64] + d := dst[base:][0:64] + // 1 + *(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl + block.Encrypt(tbl, d[0:8]) + // 2 + *(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_tbl + block.Encrypt(tbl, d[8:16]) + // 3 + *(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl + block.Encrypt(tbl, d[16:24]) + // 4 + *(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_tbl + block.Encrypt(tbl, d[24:32]) + // 5 + *(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl + block.Encrypt(tbl, d[32:40]) + // 6 + *(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_tbl + block.Encrypt(tbl, d[40:48]) + // 7 + *(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl + block.Encrypt(tbl, d[48:56]) + // 8 + *(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_tbl + block.Encrypt(tbl, d[56:64]) + base += 64 + } + + switch left { + case 7: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 6: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 5: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 4: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 3: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 2: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 1: + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl + block.Encrypt(tbl, dst[base:]) + base += 8 + fallthrough + case 0: + xorBytes(dst[base:], src[base:], tbl) + } +} + +// optimized encryption for the ciphers which works in 16-bytes +func encrypt16(block cipher.Block, dst, src, buf []byte) { + tbl := buf[:16] + block.Encrypt(tbl, initialVector) + n := len(src) / 16 + base := 0 + repeat := n / 8 + left := n % 8 + for i := 0; i < repeat; i++ { + s := src[base:][0:128] + d := dst[base:][0:128] + // 1 + xor.Bytes16Align(d[0:16], s[0:16], tbl) + block.Encrypt(tbl, d[0:16]) + // 2 + xor.Bytes16Align(d[16:32], s[16:32], tbl) + block.Encrypt(tbl, d[16:32]) + // 3 + xor.Bytes16Align(d[32:48], 
s[32:48], tbl) + block.Encrypt(tbl, d[32:48]) + // 4 + xor.Bytes16Align(d[48:64], s[48:64], tbl) + block.Encrypt(tbl, d[48:64]) + // 5 + xor.Bytes16Align(d[64:80], s[64:80], tbl) + block.Encrypt(tbl, d[64:80]) + // 6 + xor.Bytes16Align(d[80:96], s[80:96], tbl) + block.Encrypt(tbl, d[80:96]) + // 7 + xor.Bytes16Align(d[96:112], s[96:112], tbl) + block.Encrypt(tbl, d[96:112]) + // 8 + xor.Bytes16Align(d[112:128], s[112:128], tbl) + block.Encrypt(tbl, d[112:128]) + base += 128 + } + + switch left { + case 7: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 6: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 5: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 4: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 3: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 2: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 1: + xor.Bytes16Align(dst[base:], src[base:], tbl) + block.Encrypt(tbl, dst[base:]) + base += 16 + fallthrough + case 0: + xorBytes(dst[base:], src[base:], tbl) + } +} + +// decryption +func decrypt(block cipher.Block, dst, src, buf []byte) { + switch block.BlockSize() { + case 8: + decrypt8(block, dst, src, buf) + case 16: + decrypt16(block, dst, src, buf) + default: + panic("unsupported cipher block size") + } +} + +// decrypt 8 bytes block, all byte slices are supposed to be 64bit aligned +func decrypt8(block cipher.Block, dst, src, buf []byte) { + tbl := buf[0:8] + next := buf[8:16] + block.Encrypt(tbl, initialVector) + n := len(src) / 8 + base := 0 + repeat := n / 8 + left := n % 8 + ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0])) + ptr_next := (*uint64)(unsafe.Pointer(&next[0])) + + for i := 0; i < repeat; i++ { + s := src[base:][0:64] + d := dst[base:][0:64] + // 1 + block.Encrypt(next, s[0:8]) + *(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl + // 2 + block.Encrypt(tbl, s[8:16]) + *(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_next + // 3 + block.Encrypt(next, s[16:24]) + *(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl + // 4 + block.Encrypt(tbl, s[24:32]) + *(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_next + // 5 + block.Encrypt(next, s[32:40]) + *(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl + // 6 + block.Encrypt(tbl, s[40:48]) + *(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_next + // 7 + block.Encrypt(next, s[48:56]) + *(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl + // 8 + block.Encrypt(tbl, s[56:64]) + *(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_next + base += 64 + } + + switch left { + case 7: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 6: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 5: + 
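// leftover n%8 blocks: decrypt one 8-byte block per case, swapping the tbl/next keystream buffers, then fall through +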
block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 4: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 3: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 2: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 1: + block.Encrypt(next, src[base:]) + *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) + tbl, next = next, tbl + base += 8 + fallthrough + case 0: + xorBytes(dst[base:], src[base:], tbl) + } +} + +func decrypt16(block cipher.Block, dst, src, buf []byte) { + tbl := buf[0:16] + next := buf[16:32] + block.Encrypt(tbl, initialVector) + n := len(src) / 16 + base := 0 + repeat := n / 8 + left := n % 8 + for i := 0; i < repeat; i++ { + s := src[base:][0:128] + d := dst[base:][0:128] + // 1 + block.Encrypt(next, s[0:16]) + xor.Bytes16Align(d[0:16], s[0:16], tbl) + // 2 + block.Encrypt(tbl, s[16:32]) + xor.Bytes16Align(d[16:32], s[16:32], next) + // 3 + block.Encrypt(next, s[32:48]) + xor.Bytes16Align(d[32:48], s[32:48], tbl) + // 4 + block.Encrypt(tbl, s[48:64]) + xor.Bytes16Align(d[48:64], s[48:64], next) + // 5 + block.Encrypt(next, s[64:80]) + xor.Bytes16Align(d[64:80], s[64:80], tbl) + // 6 + block.Encrypt(tbl, s[80:96]) + xor.Bytes16Align(d[80:96], s[80:96], next) + // 7 + block.Encrypt(next, s[96:112]) + xor.Bytes16Align(d[96:112], s[96:112], tbl) + // 8 + block.Encrypt(tbl, s[112:128]) + xor.Bytes16Align(d[112:128], s[112:128], next) + base += 128 + } + + switch left { + case 7: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 6: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 5: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 4: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 3: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 2: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 1: + block.Encrypt(next, src[base:]) + xor.Bytes16Align(dst[base:], src[base:], tbl) + tbl, next = next, tbl + base += 16 + fallthrough + case 0: + xorBytes(dst[base:], src[base:], tbl) + } +} + +// per bytes xors +func xorBytes(dst, a, b []byte) int { + n := len(a) + if len(b) < n { + n = len(b) + } + if n == 0 { + return 0 + } + + for i := 0; i < n; i++ { + dst[i] = a[i] ^ b[i] + } + return n +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/crypt_test.go b/vendor/github.com/xtaci/kcp-go/v5/crypt_test.go new file mode 100644 index 00000000..2ef4dc8a --- /dev/null +++ 
b/vendor/github.com/xtaci/kcp-go/v5/crypt_test.go @@ -0,0 +1,289 @@ +package kcp + +import ( + "bytes" + "crypto/aes" + "crypto/md5" + "crypto/rand" + "crypto/sha1" + "hash/crc32" + "io" + "testing" +) + +func TestSM4(t *testing.T) { + bc, err := NewSM4BlockCrypt(pass[:16]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestAES(t *testing.T) { + bc, err := NewAESBlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestTEA(t *testing.T) { + bc, err := NewTEABlockCrypt(pass[:16]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestXOR(t *testing.T) { + bc, err := NewSimpleXORBlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestBlowfish(t *testing.T) { + bc, err := NewBlowfishBlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestNone(t *testing.T) { + bc, err := NewNoneBlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestCast5(t *testing.T) { + bc, err := NewCast5BlockCrypt(pass[:16]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func Test3DES(t *testing.T) { + bc, err := NewTripleDESBlockCrypt(pass[:24]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestTwofish(t *testing.T) { + bc, err := NewTwofishBlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestXTEA(t *testing.T) { + bc, err := NewXTEABlockCrypt(pass[:16]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func TestSalsa20(t *testing.T) { + bc, err := NewSalsa20BlockCrypt(pass[:32]) + if err != nil { + t.Fatal(err) + } + cryptTest(t, bc) +} + +func cryptTest(t *testing.T, bc BlockCrypt) { + data := make([]byte, mtuLimit) + io.ReadFull(rand.Reader, data) + dec := make([]byte, mtuLimit) + enc := make([]byte, mtuLimit) + bc.Encrypt(enc, data) + bc.Decrypt(dec, enc) + if !bytes.Equal(data, dec) { + t.Fail() + } +} + +func BenchmarkSM4(b *testing.B) { + bc, err := NewSM4BlockCrypt(pass[:16]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkAES128(b *testing.B) { + bc, err := NewAESBlockCrypt(pass[:16]) + if err != nil { + b.Fatal(err) + } + + benchCrypt(b, bc) +} + +func BenchmarkAES192(b *testing.B) { + bc, err := NewAESBlockCrypt(pass[:24]) + if err != nil { + b.Fatal(err) + } + + benchCrypt(b, bc) +} + +func BenchmarkAES256(b *testing.B) { + bc, err := NewAESBlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + + benchCrypt(b, bc) +} + +func BenchmarkTEA(b *testing.B) { + bc, err := NewTEABlockCrypt(pass[:16]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkXOR(b *testing.B) { + bc, err := NewSimpleXORBlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkBlowfish(b *testing.B) { + bc, err := NewBlowfishBlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkNone(b *testing.B) { + bc, err := NewNoneBlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkCast5(b *testing.B) { + bc, err := NewCast5BlockCrypt(pass[:16]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func Benchmark3DES(b *testing.B) { + bc, err := NewTripleDESBlockCrypt(pass[:24]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkTwofish(b *testing.B) { + bc, err := NewTwofishBlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func 
BenchmarkXTEA(b *testing.B) { + bc, err := NewXTEABlockCrypt(pass[:16]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func BenchmarkSalsa20(b *testing.B) { + bc, err := NewSalsa20BlockCrypt(pass[:32]) + if err != nil { + b.Fatal(err) + } + benchCrypt(b, bc) +} + +func benchCrypt(b *testing.B, bc BlockCrypt) { + data := make([]byte, mtuLimit) + io.ReadFull(rand.Reader, data) + dec := make([]byte, mtuLimit) + enc := make([]byte, mtuLimit) + + b.ReportAllocs() + b.SetBytes(int64(len(enc) * 2)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + bc.Encrypt(enc, data) + bc.Decrypt(dec, enc) + } +} + +func BenchmarkCRC32(b *testing.B) { + content := make([]byte, 1024) + b.SetBytes(int64(len(content))) + for i := 0; i < b.N; i++ { + crc32.ChecksumIEEE(content) + } +} + +func BenchmarkCsprngSystem(b *testing.B) { + data := make([]byte, md5.Size) + b.SetBytes(int64(len(data))) + + for i := 0; i < b.N; i++ { + io.ReadFull(rand.Reader, data) + } +} + +func BenchmarkCsprngMD5(b *testing.B) { + var data [md5.Size]byte + b.SetBytes(md5.Size) + + for i := 0; i < b.N; i++ { + data = md5.Sum(data[:]) + } +} +func BenchmarkCsprngSHA1(b *testing.B) { + var data [sha1.Size]byte + b.SetBytes(sha1.Size) + + for i := 0; i < b.N; i++ { + data = sha1.Sum(data[:]) + } +} + +func BenchmarkCsprngNonceMD5(b *testing.B) { + var ng nonceMD5 + ng.Init() + b.SetBytes(md5.Size) + data := make([]byte, md5.Size) + for i := 0; i < b.N; i++ { + ng.Fill(data) + } +} + +func BenchmarkCsprngNonceAES128(b *testing.B) { + var ng nonceAES128 + ng.Init() + + b.SetBytes(aes.BlockSize) + data := make([]byte, aes.BlockSize) + for i := 0; i < b.N; i++ { + ng.Fill(data) + } +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/entropy.go b/vendor/github.com/xtaci/kcp-go/v5/entropy.go new file mode 100644 index 00000000..156c1cd2 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/entropy.go @@ -0,0 +1,52 @@ +package kcp + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "crypto/rand" + "io" +) + +// Entropy defines a entropy source +type Entropy interface { + Init() + Fill(nonce []byte) +} + +// nonceMD5 nonce generator for packet header +type nonceMD5 struct { + seed [md5.Size]byte +} + +func (n *nonceMD5) Init() { /*nothing required*/ } + +func (n *nonceMD5) Fill(nonce []byte) { + if n.seed[0] == 0 { // entropy update + io.ReadFull(rand.Reader, n.seed[:]) + } + n.seed = md5.Sum(n.seed[:]) + copy(nonce, n.seed[:]) +} + +// nonceAES128 nonce generator for packet headers +type nonceAES128 struct { + seed [aes.BlockSize]byte + block cipher.Block +} + +func (n *nonceAES128) Init() { + var key [16]byte //aes-128 + io.ReadFull(rand.Reader, key[:]) + io.ReadFull(rand.Reader, n.seed[:]) + block, _ := aes.NewCipher(key[:]) + n.block = block +} + +func (n *nonceAES128) Fill(nonce []byte) { + if n.seed[0] == 0 { // entropy update + io.ReadFull(rand.Reader, n.seed[:]) + } + n.block.Encrypt(n.seed[:], n.seed[:]) + copy(nonce, n.seed[:]) +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/examples/echo.go b/vendor/github.com/xtaci/kcp-go/v5/examples/echo.go new file mode 100644 index 00000000..af73db79 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/examples/echo.go @@ -0,0 +1,77 @@ +package main + +import ( + "crypto/sha1" + "io" + "log" + "time" + + "github.com/xtaci/kcp-go/v5" + "golang.org/x/crypto/pbkdf2" +) + +func main() { + key := pbkdf2.Key([]byte("demo pass"), []byte("demo salt"), 1024, 32, sha1.New) + block, _ := kcp.NewAESBlockCrypt(key) + if listener, err := kcp.ListenWithOptions("127.0.0.1:12345", block, 10, 3); 
err == nil { + // spin-up the client + go client() + for { + s, err := listener.AcceptKCP() + if err != nil { + log.Fatal(err) + } + go handleEcho(s) + } + } else { + log.Fatal(err) + } +} + +// handleEcho send back everything it received +func handleEcho(conn *kcp.UDPSession) { + buf := make([]byte, 4096) + for { + n, err := conn.Read(buf) + if err != nil { + log.Println(err) + return + } + + n, err = conn.Write(buf[:n]) + if err != nil { + log.Println(err) + return + } + } +} + +func client() { + key := pbkdf2.Key([]byte("demo pass"), []byte("demo salt"), 1024, 32, sha1.New) + block, _ := kcp.NewAESBlockCrypt(key) + + // wait for server to become ready + time.Sleep(time.Second) + + // dial to the echo server + if sess, err := kcp.DialWithOptions("127.0.0.1:12345", block, 10, 3); err == nil { + for { + data := time.Now().String() + buf := make([]byte, len(data)) + log.Println("sent:", data) + if _, err := sess.Write([]byte(data)); err == nil { + // read back the data + if _, err := io.ReadFull(sess, buf); err == nil { + log.Println("recv:", string(buf)) + } else { + log.Fatal(err) + } + } else { + log.Fatal(err) + } + time.Sleep(time.Second) + } + } else { + log.Fatal(err) + } +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/fec.go b/vendor/github.com/xtaci/kcp-go/v5/fec.go new file mode 100644 index 00000000..0a203ee3 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/fec.go @@ -0,0 +1,381 @@ +package kcp + +import ( + "encoding/binary" + "sync/atomic" + + "github.com/klauspost/reedsolomon" +) + +const ( + fecHeaderSize = 6 + fecHeaderSizePlus2 = fecHeaderSize + 2 // plus 2B data size + typeData = 0xf1 + typeParity = 0xf2 + fecExpire = 60000 + rxFECMulti = 3 // FEC keeps rxFECMulti* (dataShard+parityShard) ordered packets in memory +) + +// fecPacket is a decoded FEC packet +type fecPacket []byte + +func (bts fecPacket) seqid() uint32 { return binary.LittleEndian.Uint32(bts) } +func (bts fecPacket) flag() uint16 { return binary.LittleEndian.Uint16(bts[4:]) } +func (bts fecPacket) data() []byte { return bts[6:] } + +// fecElement has auxcilliary time field +type fecElement struct { + fecPacket + ts uint32 +} + +// fecDecoder for decoding incoming packets +type fecDecoder struct { + rxlimit int // queue size limit + dataShards int + parityShards int + shardSize int + rx []fecElement // ordered receive queue + + // caches + decodeCache [][]byte + flagCache []bool + + // zeros + zeros []byte + + // RS decoder + codec reedsolomon.Encoder + + // auto tune fec parameter + autoTune autoTune +} + +func newFECDecoder(dataShards, parityShards int) *fecDecoder { + if dataShards <= 0 || parityShards <= 0 { + return nil + } + + dec := new(fecDecoder) + dec.dataShards = dataShards + dec.parityShards = parityShards + dec.shardSize = dataShards + parityShards + dec.rxlimit = rxFECMulti * dec.shardSize + codec, err := reedsolomon.New(dataShards, parityShards) + if err != nil { + return nil + } + dec.codec = codec + dec.decodeCache = make([][]byte, dec.shardSize) + dec.flagCache = make([]bool, dec.shardSize) + dec.zeros = make([]byte, mtuLimit) + return dec +} + +// decode a fec packet +func (dec *fecDecoder) decode(in fecPacket) (recovered [][]byte) { + // sample to auto FEC tuner + if in.flag() == typeData { + dec.autoTune.Sample(true, in.seqid()) + } else { + dec.autoTune.Sample(false, in.seqid()) + } + + // check if FEC parameters is out of sync + var shouldTune bool + if int(in.seqid())%dec.shardSize < dec.dataShards { + if in.flag() != typeData { // expect typeData + shouldTune = true + } + } 
else { + if in.flag() != typeParity { + shouldTune = true + } + } + + if shouldTune { + autoDS := dec.autoTune.FindPeriod(true) + autoPS := dec.autoTune.FindPeriod(false) + + // edges found, we can tune parameters now + if autoDS > 0 && autoPS > 0 && autoDS < 256 && autoPS < 256 { + // and make sure it's different + if autoDS != dec.dataShards || autoPS != dec.parityShards { + dec.dataShards = autoDS + dec.parityShards = autoPS + dec.shardSize = autoDS + autoPS + dec.rxlimit = rxFECMulti * dec.shardSize + codec, err := reedsolomon.New(autoDS, autoPS) + if err != nil { + return nil + } + dec.codec = codec + dec.decodeCache = make([][]byte, dec.shardSize) + dec.flagCache = make([]bool, dec.shardSize) + //log.Println("autotune to :", dec.dataShards, dec.parityShards) + } + } + } + + // insertion + n := len(dec.rx) - 1 + insertIdx := 0 + for i := n; i >= 0; i-- { + if in.seqid() == dec.rx[i].seqid() { // de-duplicate + return nil + } else if _itimediff(in.seqid(), dec.rx[i].seqid()) > 0 { // insertion + insertIdx = i + 1 + break + } + } + + // make a copy + pkt := fecPacket(xmitBuf.Get().([]byte)[:len(in)]) + copy(pkt, in) + elem := fecElement{pkt, currentMs()} + + // insert into ordered rx queue + if insertIdx == n+1 { + dec.rx = append(dec.rx, elem) + } else { + dec.rx = append(dec.rx, fecElement{}) + copy(dec.rx[insertIdx+1:], dec.rx[insertIdx:]) // shift right + dec.rx[insertIdx] = elem + } + + // shard range for current packet + shardBegin := pkt.seqid() - pkt.seqid()%uint32(dec.shardSize) + shardEnd := shardBegin + uint32(dec.shardSize) - 1 + + // max search range in ordered queue for current shard + searchBegin := insertIdx - int(pkt.seqid()%uint32(dec.shardSize)) + if searchBegin < 0 { + searchBegin = 0 + } + searchEnd := searchBegin + dec.shardSize - 1 + if searchEnd >= len(dec.rx) { + searchEnd = len(dec.rx) - 1 + } + + // re-construct datashards + if searchEnd-searchBegin+1 >= dec.dataShards { + var numshard, numDataShard, first, maxlen int + + // zero caches + shards := dec.decodeCache + shardsflag := dec.flagCache + for k := range dec.decodeCache { + shards[k] = nil + shardsflag[k] = false + } + + // shard assembly + for i := searchBegin; i <= searchEnd; i++ { + seqid := dec.rx[i].seqid() + if _itimediff(seqid, shardEnd) > 0 { + break + } else if _itimediff(seqid, shardBegin) >= 0 { + shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data() + shardsflag[seqid%uint32(dec.shardSize)] = true + numshard++ + if dec.rx[i].flag() == typeData { + numDataShard++ + } + if numshard == 1 { + first = i + } + if len(dec.rx[i].data()) > maxlen { + maxlen = len(dec.rx[i].data()) + } + } + } + + if numDataShard == dec.dataShards { + // case 1: no loss on data shards + dec.rx = dec.freeRange(first, numshard, dec.rx) + } else if numshard >= dec.dataShards { + // case 2: loss on data shards, but it's recoverable from parity shards + for k := range shards { + if shards[k] != nil { + dlen := len(shards[k]) + shards[k] = shards[k][:maxlen] + copy(shards[k][dlen:], dec.zeros) + } else if k < dec.dataShards { + shards[k] = xmitBuf.Get().([]byte)[:0] + } + } + if err := dec.codec.ReconstructData(shards); err == nil { + for k := range shards[:dec.dataShards] { + if !shardsflag[k] { + // recovered data should be recycled + recovered = append(recovered, shards[k]) + } + } + } + dec.rx = dec.freeRange(first, numshard, dec.rx) + } + } + + // keep rxlimit + if len(dec.rx) > dec.rxlimit { + if dec.rx[0].flag() == typeData { // track the unrecoverable data + atomic.AddUint64(&DefaultSnmp.FECShortShards, 1) + } + 
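// evict only the oldest element, recycling its buffer back to xmitBuf +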
dec.rx = dec.freeRange(0, 1, dec.rx) + } + + // timeout policy + current := currentMs() + numExpired := 0 + for k := range dec.rx { + if _itimediff(current, dec.rx[k].ts) > fecExpire { + numExpired++ + continue + } + break + } + if numExpired > 0 { + dec.rx = dec.freeRange(0, numExpired, dec.rx) + } + return +} + +// free a range of fecPacket +func (dec *fecDecoder) freeRange(first, n int, q []fecElement) []fecElement { + for i := first; i < first+n; i++ { // recycle buffer + xmitBuf.Put([]byte(q[i].fecPacket)) + } + + if first == 0 && n < cap(q)/2 { + return q[n:] + } + copy(q[first:], q[first+n:]) + return q[:len(q)-n] +} + +// release all segments back to xmitBuf +func (dec *fecDecoder) release() { + if n := len(dec.rx); n > 0 { + dec.rx = dec.freeRange(0, n, dec.rx) + } +} + +type ( + // fecEncoder for encoding outgoing packets + fecEncoder struct { + dataShards int + parityShards int + shardSize int + paws uint32 // Protect Against Wrapped Sequence numbers + next uint32 // next seqid + + shardCount int // count the number of datashards collected + maxSize int // track maximum data length in datashard + + headerOffset int // FEC header offset + payloadOffset int // FEC payload offset + + // caches + shardCache [][]byte + encodeCache [][]byte + + // zeros + zeros []byte + + // RS encoder + codec reedsolomon.Encoder + } +) + +func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder { + if dataShards <= 0 || parityShards <= 0 { + return nil + } + enc := new(fecEncoder) + enc.dataShards = dataShards + enc.parityShards = parityShards + enc.shardSize = dataShards + parityShards + enc.paws = 0xffffffff / uint32(enc.shardSize) * uint32(enc.shardSize) + enc.headerOffset = offset + enc.payloadOffset = enc.headerOffset + fecHeaderSize + + codec, err := reedsolomon.New(dataShards, parityShards) + if err != nil { + return nil + } + enc.codec = codec + + // caches + enc.encodeCache = make([][]byte, enc.shardSize) + enc.shardCache = make([][]byte, enc.shardSize) + for k := range enc.shardCache { + enc.shardCache[k] = make([]byte, mtuLimit) + } + enc.zeros = make([]byte, mtuLimit) + return enc +} + +// encodes the packet, outputs parity shards if we have collected quorum datashards +// notice: the contents of 'ps' will be re-written in successive calling +func (enc *fecEncoder) encode(b []byte) (ps [][]byte) { + // The header format: + // | FEC SEQID(4B) | FEC TYPE(2B) | SIZE (2B) | PAYLOAD(SIZE-2) | + // |<-headerOffset |<-payloadOffset + enc.markData(b[enc.headerOffset:]) + binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:]))) + + // copy data from payloadOffset to fec shard cache + sz := len(b) + enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz] + copy(enc.shardCache[enc.shardCount][enc.payloadOffset:], b[enc.payloadOffset:]) + enc.shardCount++ + + // track max datashard length + if sz > enc.maxSize { + enc.maxSize = sz + } + + // Generation of Reed-Solomon Erasure Code + if enc.shardCount == enc.dataShards { + // fill '0' into the tail of each datashard + for i := 0; i < enc.dataShards; i++ { + shard := enc.shardCache[i] + slen := len(shard) + copy(shard[slen:enc.maxSize], enc.zeros) + } + + // construct equal-sized slice with stripped header + cache := enc.encodeCache + for k := range cache { + cache[k] = enc.shardCache[k][enc.payloadOffset:enc.maxSize] + } + + // encoding + if err := enc.codec.Encode(cache); err == nil { + ps = enc.shardCache[enc.dataShards:] + for k := range ps { + enc.markParity(ps[k][enc.headerOffset:]) + ps[k] 
= ps[k][:enc.maxSize] + } + } + + // counters resetting + enc.shardCount = 0 + enc.maxSize = 0 + } + + return +} + +func (enc *fecEncoder) markData(data []byte) { + binary.LittleEndian.PutUint32(data, enc.next) + binary.LittleEndian.PutUint16(data[4:], typeData) + enc.next++ +} + +func (enc *fecEncoder) markParity(data []byte) { + binary.LittleEndian.PutUint32(data, enc.next) + binary.LittleEndian.PutUint16(data[4:], typeParity) + // sequence wrap will only happen at parity shard + enc.next = (enc.next + 1) % enc.paws +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/fec_test.go b/vendor/github.com/xtaci/kcp-go/v5/fec_test.go new file mode 100644 index 00000000..59b64aca --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/fec_test.go @@ -0,0 +1,43 @@ +package kcp + +import ( + "encoding/binary" + "math/rand" + "testing" +) + +func BenchmarkFECDecode(b *testing.B) { + const dataSize = 10 + const paritySize = 3 + const payLoad = 1500 + decoder := newFECDecoder(dataSize, paritySize) + b.ReportAllocs() + b.SetBytes(payLoad) + for i := 0; i < b.N; i++ { + if rand.Int()%(dataSize+paritySize) == 0 { // random loss + continue + } + pkt := make([]byte, payLoad) + binary.LittleEndian.PutUint32(pkt, uint32(i)) + if i%(dataSize+paritySize) >= dataSize { + binary.LittleEndian.PutUint16(pkt[4:], typeParity) + } else { + binary.LittleEndian.PutUint16(pkt[4:], typeData) + } + decoder.decode(pkt) + } +} + +func BenchmarkFECEncode(b *testing.B) { + const dataSize = 10 + const paritySize = 3 + const payLoad = 1500 + + b.ReportAllocs() + b.SetBytes(payLoad) + encoder := newFECEncoder(dataSize, paritySize, 0) + for i := 0; i < b.N; i++ { + data := make([]byte, payLoad) + encoder.encode(data) + } +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/flame.png b/vendor/github.com/xtaci/kcp-go/v5/flame.png new file mode 100644 index 00000000..672f649e Binary files /dev/null and b/vendor/github.com/xtaci/kcp-go/v5/flame.png differ diff --git a/vendor/github.com/xtaci/kcp-go/v5/frame.png b/vendor/github.com/xtaci/kcp-go/v5/frame.png new file mode 100644 index 00000000..0b0aefd4 Binary files /dev/null and b/vendor/github.com/xtaci/kcp-go/v5/frame.png differ diff --git a/vendor/github.com/xtaci/kcp-go/v5/go.mod b/vendor/github.com/xtaci/kcp-go/v5/go.mod new file mode 100644 index 00000000..0247f820 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/go.mod @@ -0,0 +1,20 @@ +module github.com/xtaci/kcp-go/v5 + +require ( + github.com/klauspost/cpuid v1.3.1 // indirect + github.com/klauspost/reedsolomon v1.9.9 + github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104 // indirect + github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.6.1 + github.com/templexxx/cpu v0.0.8 // indirect + github.com/templexxx/xorsimd v0.4.1 + github.com/tjfoc/gmsm v1.3.2 + github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae + golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de + golang.org/x/net v0.0.0-20200707034311-ab3426394381 + golang.org/x/sys v0.0.0-20200808120158-1030fc2bf1d9 // indirect + golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect +) + +go 1.13 diff --git a/vendor/github.com/xtaci/kcp-go/v5/go.sum b/vendor/github.com/xtaci/kcp-go/v5/go.sum new file mode 100644 index 00000000..d583de7c --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/go.sum @@ -0,0 +1,71 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/klauspost/cpuid v1.2.4/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= +github.com/klauspost/cpuid v1.3.1 h1:5JNjFYYQrZeKRJ0734q51WCEEn2huer72Dc7K+R/b6s= +github.com/klauspost/cpuid v1.3.1/go.mod h1:bYW4mA6ZgKPob1/Dlai2LviZJO7KGI3uoWLd42rAQw4= +github.com/klauspost/reedsolomon v1.9.9 h1:qCL7LZlv17xMixl55nq2/Oa1Y86nfO8EqDfv2GHND54= +github.com/klauspost/reedsolomon v1.9.9/go.mod h1:O7yFFHiQwDR6b2t63KPUpccPtNdp5ADgh1gg4fd12wo= +github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104 h1:ULR/QWMgcgRiZLUjSSJMU+fW+RDMstRdmnDWj9Q+AsA= +github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104/go.mod h1:wqKykBG2QzQDJEzvRkcS8x6MiSJkF52hXZsXcjaB3ls= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/templexxx/cpu v0.0.1 h1:hY4WdLOgKdc8y13EYklu9OUTXik80BkxHoWvTO6MQQY= +github.com/templexxx/cpu v0.0.1/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk= +github.com/templexxx/cpu v0.0.7 h1:pUEZn8JBy/w5yzdYWgx+0m0xL9uk6j4K91C5kOViAzo= +github.com/templexxx/cpu v0.0.7/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk= +github.com/templexxx/cpu v0.0.8 h1:va6GebSxedVdR5XEyPJD49t94p5ZsjWO6Wh/PfbmZnc= +github.com/templexxx/cpu v0.0.8/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk= +github.com/templexxx/xorsimd v0.4.1 h1:iUZcywbOYDRAZUasAs2eSCUW8eobuZDy0I9FJiORkVg= +github.com/templexxx/xorsimd v0.4.1/go.mod h1:W+ffZz8jJMH2SXwuKu9WhygqBMbFnp14G2fqEr8qaNo= +github.com/tjfoc/gmsm v1.3.2 h1:7JVkAn5bvUJ7HtU08iW6UiD+UTmJTIToHCfeFzkcCxM= +github.com/tjfoc/gmsm v1.3.2/go.mod h1:HaUcFuY0auTiaHB9MHFGCPx5IaLhTUd2atbCFBQXn9w= +github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae h1:J0GxkO96kL4WF+AIT3M4mfUVinOCPgf2uUWYFUzN0sM= +github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae/go.mod h1:gXtu8J62kEgmN++bm9BVICuT/e8yiLI2KFobd/TRFsE= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/arch v0.0.0-20190909030613-46d78d1859ac/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191219195013-becbf705a915/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de h1:ikNHVSjEfnvz6sxdSPCaPt572qowuyMDMJLLm3Db3ig= +golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod 
h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgNrpq9mjcfDemuexIKsU= +golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200808120158-1030fc2bf1d9 h1:yi1hN8dcqI9l8klZfy4B8mJvFmmAxJEePIQQFNSd7Cs= +golang.org/x/sys v0.0.0-20200808120158-1030fc2bf1d9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200425043458-8463f397d07c h1:iHhCR0b26amDCiiO+kBguKZom9aMF+NrFxh9zeKR/XU= +golang.org/x/tools v0.0.0-20200425043458-8463f397d07c/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123 h1:4JSJPND/+4555t1HfXYF4UEqDqiSKCgeV0+hbA8hMs4= +golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/vendor/github.com/xtaci/kcp-go/v5/kcp-go.png b/vendor/github.com/xtaci/kcp-go/v5/kcp-go.png new file mode 100644 index 00000000..151b7c4f Binary files /dev/null and b/vendor/github.com/xtaci/kcp-go/v5/kcp-go.png differ diff 
--git a/vendor/github.com/xtaci/kcp-go/v5/kcp.go b/vendor/github.com/xtaci/kcp-go/v5/kcp.go new file mode 100644 index 00000000..0c6c304f --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/kcp.go @@ -0,0 +1,1080 @@ +package kcp + +import ( + "encoding/binary" + "sync/atomic" + "time" +) + +const ( + IKCP_RTO_NDL = 30 // no delay min rto + IKCP_RTO_MIN = 100 // normal min rto + IKCP_RTO_DEF = 200 + IKCP_RTO_MAX = 60000 + IKCP_CMD_PUSH = 81 // cmd: push data + IKCP_CMD_ACK = 82 // cmd: ack + IKCP_CMD_WASK = 83 // cmd: window probe (ask) + IKCP_CMD_WINS = 84 // cmd: window size (tell) + IKCP_ASK_SEND = 1 // need to send IKCP_CMD_WASK + IKCP_ASK_TELL = 2 // need to send IKCP_CMD_WINS + IKCP_WND_SND = 32 + IKCP_WND_RCV = 32 + IKCP_MTU_DEF = 1400 + IKCP_ACK_FAST = 3 + IKCP_INTERVAL = 100 + IKCP_OVERHEAD = 24 + IKCP_DEADLINK = 20 + IKCP_THRESH_INIT = 2 + IKCP_THRESH_MIN = 2 + IKCP_PROBE_INIT = 7000 // 7 secs to probe window size + IKCP_PROBE_LIMIT = 120000 // up to 120 secs to probe window + IKCP_SN_OFFSET = 12 +) + +// monotonic reference time point +var refTime time.Time = time.Now() + +// currentMs returns current elapsed monotonic milliseconds since program startup +func currentMs() uint32 { return uint32(time.Since(refTime) / time.Millisecond) } + +// output_callback is a prototype which ought capture conn and call conn.Write +type output_callback func(buf []byte, size int) + +/* encode 8 bits unsigned int */ +func ikcp_encode8u(p []byte, c byte) []byte { + p[0] = c + return p[1:] +} + +/* decode 8 bits unsigned int */ +func ikcp_decode8u(p []byte, c *byte) []byte { + *c = p[0] + return p[1:] +} + +/* encode 16 bits unsigned int (lsb) */ +func ikcp_encode16u(p []byte, w uint16) []byte { + binary.LittleEndian.PutUint16(p, w) + return p[2:] +} + +/* decode 16 bits unsigned int (lsb) */ +func ikcp_decode16u(p []byte, w *uint16) []byte { + *w = binary.LittleEndian.Uint16(p) + return p[2:] +} + +/* encode 32 bits unsigned int (lsb) */ +func ikcp_encode32u(p []byte, l uint32) []byte { + binary.LittleEndian.PutUint32(p, l) + return p[4:] +} + +/* decode 32 bits unsigned int (lsb) */ +func ikcp_decode32u(p []byte, l *uint32) []byte { + *l = binary.LittleEndian.Uint32(p) + return p[4:] +} + +func _imin_(a, b uint32) uint32 { + if a <= b { + return a + } + return b +} + +func _imax_(a, b uint32) uint32 { + if a >= b { + return a + } + return b +} + +func _ibound_(lower, middle, upper uint32) uint32 { + return _imin_(_imax_(lower, middle), upper) +} + +func _itimediff(later, earlier uint32) int32 { + return (int32)(later - earlier) +} + +// segment defines a KCP segment +type segment struct { + conv uint32 + cmd uint8 + frg uint8 + wnd uint16 + ts uint32 + sn uint32 + una uint32 + rto uint32 + xmit uint32 + resendts uint32 + fastack uint32 + acked uint32 // mark if the seg has acked + data []byte +} + +// encode a segment into buffer +func (seg *segment) encode(ptr []byte) []byte { + ptr = ikcp_encode32u(ptr, seg.conv) + ptr = ikcp_encode8u(ptr, seg.cmd) + ptr = ikcp_encode8u(ptr, seg.frg) + ptr = ikcp_encode16u(ptr, seg.wnd) + ptr = ikcp_encode32u(ptr, seg.ts) + ptr = ikcp_encode32u(ptr, seg.sn) + ptr = ikcp_encode32u(ptr, seg.una) + ptr = ikcp_encode32u(ptr, uint32(len(seg.data))) + atomic.AddUint64(&DefaultSnmp.OutSegs, 1) + return ptr +} + +// KCP defines a single KCP connection +type KCP struct { + conv, mtu, mss, state uint32 + snd_una, snd_nxt, rcv_nxt uint32 + ssthresh uint32 + rx_rttvar, rx_srtt int32 + rx_rto, rx_minrto uint32 + snd_wnd, rcv_wnd, rmt_wnd, cwnd, probe uint32 + interval, 
ts_flush uint32
+	nodelay, updated     uint32
+	ts_probe, probe_wait uint32
+	dead_link, incr      uint32
+
+	fastresend     int32
+	nocwnd, stream int32
+
+	snd_queue []segment
+	rcv_queue []segment
+	snd_buf   []segment
+	rcv_buf   []segment
+
+	acklist []ackItem
+
+	buffer   []byte
+	reserved int
+	output   output_callback
+}
+
+type ackItem struct {
+	sn uint32
+	ts uint32
+}
+
+// NewKCP creates a new kcp state machine
+//
+// 'conv' must be equal in the connection peers, or else data will be silently rejected.
+//
+// 'output' function will be called whenever there is data to be sent on the wire.
+func NewKCP(conv uint32, output output_callback) *KCP {
+	kcp := new(KCP)
+	kcp.conv = conv
+	kcp.snd_wnd = IKCP_WND_SND
+	kcp.rcv_wnd = IKCP_WND_RCV
+	kcp.rmt_wnd = IKCP_WND_RCV
+	kcp.mtu = IKCP_MTU_DEF
+	kcp.mss = kcp.mtu - IKCP_OVERHEAD
+	kcp.buffer = make([]byte, kcp.mtu)
+	kcp.rx_rto = IKCP_RTO_DEF
+	kcp.rx_minrto = IKCP_RTO_MIN
+	kcp.interval = IKCP_INTERVAL
+	kcp.ts_flush = IKCP_INTERVAL
+	kcp.ssthresh = IKCP_THRESH_INIT
+	kcp.dead_link = IKCP_DEADLINK
+	kcp.output = output
+	return kcp
+}
+
+// newSegment creates a KCP segment
+func (kcp *KCP) newSegment(size int) (seg segment) {
+	seg.data = xmitBuf.Get().([]byte)[:size]
+	return
+}
+
+// delSegment recycles a KCP segment
+func (kcp *KCP) delSegment(seg *segment) {
+	if seg.data != nil {
+		xmitBuf.Put(seg.data)
+		seg.data = nil
+	}
+}
+
+// ReserveBytes keeps n bytes untouched from the beginning of the buffer,
+// the output_callback function should be aware of this.
+//
+// Return false if n >= mss
+func (kcp *KCP) ReserveBytes(n int) bool {
+	if n >= int(kcp.mtu-IKCP_OVERHEAD) || n < 0 {
+		return false
+	}
+	kcp.reserved = n
+	kcp.mss = kcp.mtu - IKCP_OVERHEAD - uint32(n)
+	return true
+}
+
+// PeekSize checks the size of the next message in the recv queue
+func (kcp *KCP) PeekSize() (length int) {
+	if len(kcp.rcv_queue) == 0 {
+		return -1
+	}
+
+	seg := &kcp.rcv_queue[0]
+	if seg.frg == 0 {
+		return len(seg.data)
+	}
+
+	if len(kcp.rcv_queue) < int(seg.frg+1) {
+		return -1
+	}
+
+	for k := range kcp.rcv_queue {
+		seg := &kcp.rcv_queue[k]
+		length += len(seg.data)
+		if seg.frg == 0 {
+			break
+		}
+	}
+	return
+}
+
+// Receive data from kcp state machine
+//
+// Return number of bytes read.
+//
+// Return -1 when there is no readable data.
+//
+// Return -2 if len(buffer) is smaller than kcp.PeekSize().
+func (kcp *KCP) Recv(buffer []byte) (n int) {
+	peeksize := kcp.PeekSize()
+	if peeksize < 0 {
+		return -1
+	}
+
+	if peeksize > len(buffer) {
+		return -2
+	}
+
+	var fast_recover bool
+	if len(kcp.rcv_queue) >= int(kcp.rcv_wnd) {
+		fast_recover = true
+	}
+
+	// merge fragment
+	count := 0
+	for k := range kcp.rcv_queue {
+		seg := &kcp.rcv_queue[k]
+		copy(buffer, seg.data)
+		buffer = buffer[len(seg.data):]
+		n += len(seg.data)
+		count++
+		kcp.delSegment(seg)
+		if seg.frg == 0 {
+			break
+		}
+	}
+	if count > 0 {
+		kcp.rcv_queue = kcp.remove_front(kcp.rcv_queue, count)
+	}
+
+	// move available data from rcv_buf -> rcv_queue
+	count = 0
+	for k := range kcp.rcv_buf {
+		seg := &kcp.rcv_buf[k]
+		if seg.sn == kcp.rcv_nxt && len(kcp.rcv_queue)+count < int(kcp.rcv_wnd) {
+			kcp.rcv_nxt++
+			count++
+		} else {
+			break
+		}
+	}
+
+	if count > 0 {
+		kcp.rcv_queue = append(kcp.rcv_queue, kcp.rcv_buf[:count]...)
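+		// drop the segments just promoted to rcv_queue from the front of rcv_buf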
+ kcp.rcv_buf = kcp.remove_front(kcp.rcv_buf, count) + } + + // fast recover + if len(kcp.rcv_queue) < int(kcp.rcv_wnd) && fast_recover { + // ready to send back IKCP_CMD_WINS in ikcp_flush + // tell remote my window size + kcp.probe |= IKCP_ASK_TELL + } + return +} + +// Send is user/upper level send, returns below zero for error +func (kcp *KCP) Send(buffer []byte) int { + var count int + if len(buffer) == 0 { + return -1 + } + + // append to previous segment in streaming mode (if possible) + if kcp.stream != 0 { + n := len(kcp.snd_queue) + if n > 0 { + seg := &kcp.snd_queue[n-1] + if len(seg.data) < int(kcp.mss) { + capacity := int(kcp.mss) - len(seg.data) + extend := capacity + if len(buffer) < capacity { + extend = len(buffer) + } + + // grow slice, the underlying cap is guaranteed to + // be larger than kcp.mss + oldlen := len(seg.data) + seg.data = seg.data[:oldlen+extend] + copy(seg.data[oldlen:], buffer) + buffer = buffer[extend:] + } + } + + if len(buffer) == 0 { + return 0 + } + } + + if len(buffer) <= int(kcp.mss) { + count = 1 + } else { + count = (len(buffer) + int(kcp.mss) - 1) / int(kcp.mss) + } + + if count > 255 { + return -2 + } + + if count == 0 { + count = 1 + } + + for i := 0; i < count; i++ { + var size int + if len(buffer) > int(kcp.mss) { + size = int(kcp.mss) + } else { + size = len(buffer) + } + seg := kcp.newSegment(size) + copy(seg.data, buffer[:size]) + if kcp.stream == 0 { // message mode + seg.frg = uint8(count - i - 1) + } else { // stream mode + seg.frg = 0 + } + kcp.snd_queue = append(kcp.snd_queue, seg) + buffer = buffer[size:] + } + return 0 +} + +func (kcp *KCP) update_ack(rtt int32) { + // https://tools.ietf.org/html/rfc6298 + var rto uint32 + if kcp.rx_srtt == 0 { + kcp.rx_srtt = rtt + kcp.rx_rttvar = rtt >> 1 + } else { + delta := rtt - kcp.rx_srtt + kcp.rx_srtt += delta >> 3 + if delta < 0 { + delta = -delta + } + if rtt < kcp.rx_srtt-kcp.rx_rttvar { + // if the new RTT sample is below the bottom of the range of + // what an RTT measurement is expected to be. 
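+			// (for example, with srtt = 200ms and rttvar = 40ms, any sample below 160ms lands here)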
+ // give an 8x reduced weight versus its normal weighting + kcp.rx_rttvar += (delta - kcp.rx_rttvar) >> 5 + } else { + kcp.rx_rttvar += (delta - kcp.rx_rttvar) >> 2 + } + } + rto = uint32(kcp.rx_srtt) + _imax_(kcp.interval, uint32(kcp.rx_rttvar)<<2) + kcp.rx_rto = _ibound_(kcp.rx_minrto, rto, IKCP_RTO_MAX) +} + +func (kcp *KCP) shrink_buf() { + if len(kcp.snd_buf) > 0 { + seg := &kcp.snd_buf[0] + kcp.snd_una = seg.sn + } else { + kcp.snd_una = kcp.snd_nxt + } +} + +func (kcp *KCP) parse_ack(sn uint32) { + if _itimediff(sn, kcp.snd_una) < 0 || _itimediff(sn, kcp.snd_nxt) >= 0 { + return + } + + for k := range kcp.snd_buf { + seg := &kcp.snd_buf[k] + if sn == seg.sn { + // mark and free space, but leave the segment here, + // and wait until `una` to delete this, then we don't + // have to shift the segments behind forward, + // which is an expensive operation for large window + seg.acked = 1 + kcp.delSegment(seg) + break + } + if _itimediff(sn, seg.sn) < 0 { + break + } + } +} + +func (kcp *KCP) parse_fastack(sn, ts uint32) { + if _itimediff(sn, kcp.snd_una) < 0 || _itimediff(sn, kcp.snd_nxt) >= 0 { + return + } + + for k := range kcp.snd_buf { + seg := &kcp.snd_buf[k] + if _itimediff(sn, seg.sn) < 0 { + break + } else if sn != seg.sn && _itimediff(seg.ts, ts) <= 0 { + seg.fastack++ + } + } +} + +func (kcp *KCP) parse_una(una uint32) int { + count := 0 + for k := range kcp.snd_buf { + seg := &kcp.snd_buf[k] + if _itimediff(una, seg.sn) > 0 { + kcp.delSegment(seg) + count++ + } else { + break + } + } + if count > 0 { + kcp.snd_buf = kcp.remove_front(kcp.snd_buf, count) + } + return count +} + +// ack append +func (kcp *KCP) ack_push(sn, ts uint32) { + kcp.acklist = append(kcp.acklist, ackItem{sn, ts}) +} + +// returns true if data has repeated +func (kcp *KCP) parse_data(newseg segment) bool { + sn := newseg.sn + if _itimediff(sn, kcp.rcv_nxt+kcp.rcv_wnd) >= 0 || + _itimediff(sn, kcp.rcv_nxt) < 0 { + return true + } + + n := len(kcp.rcv_buf) - 1 + insert_idx := 0 + repeat := false + for i := n; i >= 0; i-- { + seg := &kcp.rcv_buf[i] + if seg.sn == sn { + repeat = true + break + } + if _itimediff(sn, seg.sn) > 0 { + insert_idx = i + 1 + break + } + } + + if !repeat { + // replicate the content if it's new + dataCopy := xmitBuf.Get().([]byte)[:len(newseg.data)] + copy(dataCopy, newseg.data) + newseg.data = dataCopy + + if insert_idx == n+1 { + kcp.rcv_buf = append(kcp.rcv_buf, newseg) + } else { + kcp.rcv_buf = append(kcp.rcv_buf, segment{}) + copy(kcp.rcv_buf[insert_idx+1:], kcp.rcv_buf[insert_idx:]) + kcp.rcv_buf[insert_idx] = newseg + } + } + + // move available data from rcv_buf -> rcv_queue + count := 0 + for k := range kcp.rcv_buf { + seg := &kcp.rcv_buf[k] + if seg.sn == kcp.rcv_nxt && len(kcp.rcv_queue)+count < int(kcp.rcv_wnd) { + kcp.rcv_nxt++ + count++ + } else { + break + } + } + if count > 0 { + kcp.rcv_queue = append(kcp.rcv_queue, kcp.rcv_buf[:count]...) + kcp.rcv_buf = kcp.remove_front(kcp.rcv_buf, count) + } + + return repeat +} + +// Input a packet into kcp state machine. +// +// 'regular' indicates it's a real data packet from remote, and it means it's not generated from ReedSolomon +// codecs. 
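+//
+// Input returns 0 on success; a negative value (-1/-2/-3) indicates a short,
+// mismatched or malformed packet.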
+//
+// 'ackNoDelay' will trigger an immediate ACK, though it is certainly not bandwidth-efficient
+func (kcp *KCP) Input(data []byte, regular, ackNoDelay bool) int {
+	snd_una := kcp.snd_una
+	if len(data) < IKCP_OVERHEAD {
+		return -1
+	}
+
+	var latest uint32 // the latest ack packet
+	var flag int
+	var inSegs uint64
+	var windowSlides bool
+
+	for {
+		var ts, sn, length, una, conv uint32
+		var wnd uint16
+		var cmd, frg uint8
+
+		if len(data) < int(IKCP_OVERHEAD) {
+			break
+		}
+
+		data = ikcp_decode32u(data, &conv)
+		if conv != kcp.conv {
+			return -1
+		}
+
+		data = ikcp_decode8u(data, &cmd)
+		data = ikcp_decode8u(data, &frg)
+		data = ikcp_decode16u(data, &wnd)
+		data = ikcp_decode32u(data, &ts)
+		data = ikcp_decode32u(data, &sn)
+		data = ikcp_decode32u(data, &una)
+		data = ikcp_decode32u(data, &length)
+		if len(data) < int(length) {
+			return -2
+		}
+
+		if cmd != IKCP_CMD_PUSH && cmd != IKCP_CMD_ACK &&
+			cmd != IKCP_CMD_WASK && cmd != IKCP_CMD_WINS {
+			return -3
+		}
+
+		// only trust window updates from regular packets, i.e. the latest update
+		if regular {
+			kcp.rmt_wnd = uint32(wnd)
+		}
+		if kcp.parse_una(una) > 0 {
+			windowSlides = true
+		}
+		kcp.shrink_buf()
+
+		if cmd == IKCP_CMD_ACK {
+			kcp.parse_ack(sn)
+			kcp.parse_fastack(sn, ts)
+			flag |= 1
+			latest = ts
+		} else if cmd == IKCP_CMD_PUSH {
+			repeat := true
+			if _itimediff(sn, kcp.rcv_nxt+kcp.rcv_wnd) < 0 {
+				kcp.ack_push(sn, ts)
+				if _itimediff(sn, kcp.rcv_nxt) >= 0 {
+					var seg segment
+					seg.conv = conv
+					seg.cmd = cmd
+					seg.frg = frg
+					seg.wnd = wnd
+					seg.ts = ts
+					seg.sn = sn
+					seg.una = una
+					seg.data = data[:length] // delayed data copying
+					repeat = kcp.parse_data(seg)
+				}
+			}
+			if regular && repeat {
+				atomic.AddUint64(&DefaultSnmp.RepeatSegs, 1)
+			}
+		} else if cmd == IKCP_CMD_WASK {
+			// ready to send back IKCP_CMD_WINS in Ikcp_flush
+			// tell remote my window size
+			kcp.probe |= IKCP_ASK_TELL
+		} else if cmd == IKCP_CMD_WINS {
+			// do nothing
+		} else {
+			return -3
+		}
+
+		inSegs++
+		data = data[length:]
+	}
+	atomic.AddUint64(&DefaultSnmp.InSegs, inSegs)
+
+	// update rtt with the latest ts
+	// ignore the FEC packet
+	if flag != 0 && regular {
+		current := currentMs()
+		if _itimediff(current, latest) >= 0 {
+			kcp.update_ack(_itimediff(current, latest))
+		}
+	}
+
+	// cwnd update when packet arrived
+	if kcp.nocwnd == 0 {
+		if _itimediff(kcp.snd_una, snd_una) > 0 {
+			if kcp.cwnd < kcp.rmt_wnd {
+				mss := kcp.mss
+				if kcp.cwnd < kcp.ssthresh {
+					kcp.cwnd++
+					kcp.incr += mss
+				} else {
+					if kcp.incr < mss {
+						kcp.incr = mss
+					}
+					kcp.incr += (mss*mss)/kcp.incr + (mss / 16)
+					if (kcp.cwnd+1)*mss <= kcp.incr {
+						if mss > 0 {
+							kcp.cwnd = (kcp.incr + mss - 1) / mss
+						} else {
+							kcp.cwnd = kcp.incr + mss - 1
+						}
+					}
+				}
+				if kcp.cwnd > kcp.rmt_wnd {
+					kcp.cwnd = kcp.rmt_wnd
+					kcp.incr = kcp.rmt_wnd * mss
+				}
+			}
+		}
+	}
+
+	if windowSlides { // if the window has slid, flush
+		kcp.flush(false)
+	} else if ackNoDelay && len(kcp.acklist) > 0 { // ack immediately
+		kcp.flush(true)
+	}
+	return 0
+}
+
+func (kcp *KCP) wnd_unused() uint16 {
+	if len(kcp.rcv_queue) < int(kcp.rcv_wnd) {
+		return uint16(int(kcp.rcv_wnd) - len(kcp.rcv_queue))
+	}
+	return 0
+}
+
+// flush pending data
+func (kcp *KCP) flush(ackOnly bool) uint32 {
+	var seg segment
+	seg.conv = kcp.conv
+	seg.cmd = IKCP_CMD_ACK
+	seg.wnd = kcp.wnd_unused()
+	seg.una = kcp.rcv_nxt
+
+	buffer := kcp.buffer
+	ptr := buffer[kcp.reserved:] // keep n bytes untouched
+
+	// makeSpace makes room for writing
+	makeSpace := func(space int) {
+		size := len(buffer) - len(ptr)
+		if
size+space > int(kcp.mtu) { + kcp.output(buffer, size) + ptr = buffer[kcp.reserved:] + } + } + + // flush bytes in buffer if there is any + flushBuffer := func() { + size := len(buffer) - len(ptr) + if size > kcp.reserved { + kcp.output(buffer, size) + } + } + + // flush acknowledges + for i, ack := range kcp.acklist { + makeSpace(IKCP_OVERHEAD) + // filter jitters caused by bufferbloat + if _itimediff(ack.sn, kcp.rcv_nxt) >= 0 || len(kcp.acklist)-1 == i { + seg.sn, seg.ts = ack.sn, ack.ts + ptr = seg.encode(ptr) + } + } + kcp.acklist = kcp.acklist[0:0] + + if ackOnly { // flash remain ack segments + flushBuffer() + return kcp.interval + } + + // probe window size (if remote window size equals zero) + if kcp.rmt_wnd == 0 { + current := currentMs() + if kcp.probe_wait == 0 { + kcp.probe_wait = IKCP_PROBE_INIT + kcp.ts_probe = current + kcp.probe_wait + } else { + if _itimediff(current, kcp.ts_probe) >= 0 { + if kcp.probe_wait < IKCP_PROBE_INIT { + kcp.probe_wait = IKCP_PROBE_INIT + } + kcp.probe_wait += kcp.probe_wait / 2 + if kcp.probe_wait > IKCP_PROBE_LIMIT { + kcp.probe_wait = IKCP_PROBE_LIMIT + } + kcp.ts_probe = current + kcp.probe_wait + kcp.probe |= IKCP_ASK_SEND + } + } + } else { + kcp.ts_probe = 0 + kcp.probe_wait = 0 + } + + // flush window probing commands + if (kcp.probe & IKCP_ASK_SEND) != 0 { + seg.cmd = IKCP_CMD_WASK + makeSpace(IKCP_OVERHEAD) + ptr = seg.encode(ptr) + } + + // flush window probing commands + if (kcp.probe & IKCP_ASK_TELL) != 0 { + seg.cmd = IKCP_CMD_WINS + makeSpace(IKCP_OVERHEAD) + ptr = seg.encode(ptr) + } + + kcp.probe = 0 + + // calculate window size + cwnd := _imin_(kcp.snd_wnd, kcp.rmt_wnd) + if kcp.nocwnd == 0 { + cwnd = _imin_(kcp.cwnd, cwnd) + } + + // sliding window, controlled by snd_nxt && sna_una+cwnd + newSegsCount := 0 + for k := range kcp.snd_queue { + if _itimediff(kcp.snd_nxt, kcp.snd_una+cwnd) >= 0 { + break + } + newseg := kcp.snd_queue[k] + newseg.conv = kcp.conv + newseg.cmd = IKCP_CMD_PUSH + newseg.sn = kcp.snd_nxt + kcp.snd_buf = append(kcp.snd_buf, newseg) + kcp.snd_nxt++ + newSegsCount++ + } + if newSegsCount > 0 { + kcp.snd_queue = kcp.remove_front(kcp.snd_queue, newSegsCount) + } + + // calculate resent + resent := uint32(kcp.fastresend) + if kcp.fastresend <= 0 { + resent = 0xffffffff + } + + // check for retransmissions + current := currentMs() + var change, lostSegs, fastRetransSegs, earlyRetransSegs uint64 + minrto := int32(kcp.interval) + + ref := kcp.snd_buf[:len(kcp.snd_buf)] // for bounds check elimination + for k := range ref { + segment := &ref[k] + needsend := false + if segment.acked == 1 { + continue + } + if segment.xmit == 0 { // initial transmit + needsend = true + segment.rto = kcp.rx_rto + segment.resendts = current + segment.rto + } else if segment.fastack >= resent { // fast retransmit + needsend = true + segment.fastack = 0 + segment.rto = kcp.rx_rto + segment.resendts = current + segment.rto + change++ + fastRetransSegs++ + } else if segment.fastack > 0 && newSegsCount == 0 { // early retransmit + needsend = true + segment.fastack = 0 + segment.rto = kcp.rx_rto + segment.resendts = current + segment.rto + change++ + earlyRetransSegs++ + } else if _itimediff(current, segment.resendts) >= 0 { // RTO + needsend = true + if kcp.nodelay == 0 { + segment.rto += kcp.rx_rto + } else { + segment.rto += kcp.rx_rto / 2 + } + segment.fastack = 0 + segment.resendts = current + segment.rto + lostSegs++ + } + + if needsend { + current = currentMs() + segment.xmit++ + segment.ts = current + segment.wnd = seg.wnd + 
segment.una = seg.una
+
+			need := IKCP_OVERHEAD + len(segment.data)
+			makeSpace(need)
+			ptr = segment.encode(ptr)
+			copy(ptr, segment.data)
+			ptr = ptr[len(segment.data):]
+
+			if segment.xmit >= kcp.dead_link {
+				kcp.state = 0xFFFFFFFF
+			}
+		}
+
+		// get the nearest rto
+		if rto := _itimediff(segment.resendts, current); rto > 0 && rto < minrto {
+			minrto = rto
+		}
+	}
+
+	// flush remaining segments
+	flushBuffer()
+
+	// counter updates
+	sum := lostSegs
+	if lostSegs > 0 {
+		atomic.AddUint64(&DefaultSnmp.LostSegs, lostSegs)
+	}
+	if fastRetransSegs > 0 {
+		atomic.AddUint64(&DefaultSnmp.FastRetransSegs, fastRetransSegs)
+		sum += fastRetransSegs
+	}
+	if earlyRetransSegs > 0 {
+		atomic.AddUint64(&DefaultSnmp.EarlyRetransSegs, earlyRetransSegs)
+		sum += earlyRetransSegs
+	}
+	if sum > 0 {
+		atomic.AddUint64(&DefaultSnmp.RetransSegs, sum)
+	}
+
+	// cwnd update
+	if kcp.nocwnd == 0 {
+		// update ssthresh
+		// rate halving, https://tools.ietf.org/html/rfc6937
+		if change > 0 {
+			inflight := kcp.snd_nxt - kcp.snd_una
+			kcp.ssthresh = inflight / 2
+			if kcp.ssthresh < IKCP_THRESH_MIN {
+				kcp.ssthresh = IKCP_THRESH_MIN
+			}
+			kcp.cwnd = kcp.ssthresh + resent
+			kcp.incr = kcp.cwnd * kcp.mss
+		}
+
+		// congestion control, https://tools.ietf.org/html/rfc5681
+		if lostSegs > 0 {
+			kcp.ssthresh = cwnd / 2
+			if kcp.ssthresh < IKCP_THRESH_MIN {
+				kcp.ssthresh = IKCP_THRESH_MIN
+			}
+			kcp.cwnd = 1
+			kcp.incr = kcp.mss
+		}
+
+		if kcp.cwnd < 1 {
+			kcp.cwnd = 1
+			kcp.incr = kcp.mss
+		}
+	}
+
+	return uint32(minrto)
+}
+
+// (deprecated)
+//
+// Update updates state (call it repeatedly, every 10ms-100ms), or you can ask
+// ikcp_check when to call it again (when there is no ikcp_input/_send calling).
+// 'current' - current timestamp in millisec.
+func (kcp *KCP) Update() {
+	var slap int32
+
+	current := currentMs()
+	if kcp.updated == 0 {
+		kcp.updated = 1
+		kcp.ts_flush = current
+	}
+
+	slap = _itimediff(current, kcp.ts_flush)
+
+	if slap >= 10000 || slap < -10000 {
+		kcp.ts_flush = current
+		slap = 0
+	}
+
+	if slap >= 0 {
+		kcp.ts_flush += kcp.interval
+		if _itimediff(current, kcp.ts_flush) >= 0 {
+			kcp.ts_flush = current + kcp.interval
+		}
+		kcp.flush(false)
+	}
+}
+
+// (deprecated)
+//
+// Check determines when you should invoke ikcp_update:
+// it returns the time (in millisec) at which ikcp_update should be invoked,
+// assuming there is no ikcp_input/_send calling. You can call ikcp_update at
+// that time instead of calling update repeatedly.
+// This is important for reducing unnecessary ikcp_update invocations; use it
+// to schedule ikcp_update (e.g. implementing an epoll-like mechanism,
+// or optimizing ikcp_update when handling massive numbers of kcp connections)
+func (kcp *KCP) Check() uint32 {
+	current := currentMs()
+	ts_flush := kcp.ts_flush
+	tm_flush := int32(0x7fffffff)
+	tm_packet := int32(0x7fffffff)
+	minimal := uint32(0)
+	if kcp.updated == 0 {
+		return current
+	}
+
+	if _itimediff(current, ts_flush) >= 10000 ||
+		_itimediff(current, ts_flush) < -10000 {
+		ts_flush = current
+	}
+
+	if _itimediff(current, ts_flush) >= 0 {
+		return current
+	}
+
+	tm_flush = _itimediff(ts_flush, current)
+
+	for k := range kcp.snd_buf {
+		seg := &kcp.snd_buf[k]
+		diff := _itimediff(seg.resendts, current)
+		if diff <= 0 {
+			return current
+		}
+		if diff < tm_packet {
+			tm_packet = diff
+		}
+	}
+
+	minimal = uint32(tm_packet)
+	if tm_packet >= tm_flush {
+		minimal = uint32(tm_flush)
+	}
+	if minimal >= kcp.interval {
+		minimal = kcp.interval
+	}
+
+	return current + minimal
+}
+
+// SetMtu changes MTU size, default is 1400
+func (kcp *KCP) SetMtu(mtu int) int {
+	if mtu < 50 || mtu < IKCP_OVERHEAD {
+		return -1
+	}
+	if kcp.reserved >= int(kcp.mtu-IKCP_OVERHEAD) || kcp.reserved < 0 {
+		return -1
+	}
+
+	buffer := make([]byte, mtu)
+	if buffer == nil {
+		return -2
+	}
+	kcp.mtu = uint32(mtu)
+	kcp.mss = kcp.mtu - IKCP_OVERHEAD - uint32(kcp.reserved)
+	kcp.buffer = buffer
+	return 0
+}
+
+// NoDelay options
+// fastest: ikcp_nodelay(kcp, 1, 20, 2, 1)
+// nodelay: 0:disable(default), 1:enable
+// interval: internal update timer interval in millisec, default is 100ms
+// resend: 0:disable fast resend(default), 1:enable fast resend
+// nc: 0:normal congestion control(default), 1:disable congestion control
+func (kcp *KCP) NoDelay(nodelay, interval, resend, nc int) int {
+	if nodelay >= 0 {
+		kcp.nodelay = uint32(nodelay)
+		if nodelay != 0 {
+			kcp.rx_minrto = IKCP_RTO_NDL
+		} else {
+			kcp.rx_minrto = IKCP_RTO_MIN
+		}
+	}
+	if interval >= 0 {
+		if interval > 5000 {
+			interval = 5000
+		} else if interval < 10 {
+			interval = 10
+		}
+		kcp.interval = uint32(interval)
+	}
+	if resend >= 0 {
+		kcp.fastresend = int32(resend)
+	}
+	if nc >= 0 {
+		kcp.nocwnd = int32(nc)
+	}
+	return 0
+}
+
+// WndSize sets maximum window size: sndwnd=32, rcvwnd=32 by default
+func (kcp *KCP) WndSize(sndwnd, rcvwnd int) int {
+	if sndwnd > 0 {
+		kcp.snd_wnd = uint32(sndwnd)
+	}
+	if rcvwnd > 0 {
+		kcp.rcv_wnd = uint32(rcvwnd)
+	}
+	return 0
+}
+
+// WaitSnd gets how many packets are waiting to be sent
+func (kcp *KCP) WaitSnd() int {
+	return len(kcp.snd_buf) + len(kcp.snd_queue)
+}
+
+// remove front n elements from queue
+// if the number of elements to remove is more than half of the size,
+// just shift the rear elements to front, otherwise just reslice q to q[n:] +// then the cost of runtime.growslice can always be less than n/2 +func (kcp *KCP) remove_front(q []segment, n int) []segment { + if n > cap(q)/2 { + newn := copy(q, q[n:]) + return q[:newn] + } + return q[n:] +} + +// Release all cached outgoing segments +func (kcp *KCP) ReleaseTX() { + for k := range kcp.snd_queue { + if kcp.snd_queue[k].data != nil { + xmitBuf.Put(kcp.snd_queue[k].data) + } + } + for k := range kcp.snd_buf { + if kcp.snd_buf[k].data != nil { + xmitBuf.Put(kcp.snd_buf[k].data) + } + } + kcp.snd_queue = nil + kcp.snd_buf = nil +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/kcp_test.go b/vendor/github.com/xtaci/kcp-go/v5/kcp_test.go new file mode 100644 index 00000000..49d55d5a --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/kcp_test.go @@ -0,0 +1,135 @@ +package kcp + +import ( + "io" + "net" + "sync" + "testing" + "time" + + "github.com/xtaci/lossyconn" +) + +const repeat = 16 + +func TestLossyConn1(t *testing.T) { + t.Log("testing loss rate 10%, rtt 200ms") + t.Log("testing link with nodelay parameters:1 10 2 1") + client, err := lossyconn.NewLossyConn(0.1, 100) + if err != nil { + t.Fatal(err) + } + + server, err := lossyconn.NewLossyConn(0.1, 100) + if err != nil { + t.Fatal(err) + } + testlink(t, client, server, 1, 10, 2, 1) +} + +func TestLossyConn2(t *testing.T) { + t.Log("testing loss rate 20%, rtt 200ms") + t.Log("testing link with nodelay parameters:1 10 2 1") + client, err := lossyconn.NewLossyConn(0.2, 100) + if err != nil { + t.Fatal(err) + } + + server, err := lossyconn.NewLossyConn(0.2, 100) + if err != nil { + t.Fatal(err) + } + testlink(t, client, server, 1, 10, 2, 1) +} + +func TestLossyConn3(t *testing.T) { + t.Log("testing loss rate 30%, rtt 200ms") + t.Log("testing link with nodelay parameters:1 10 2 1") + client, err := lossyconn.NewLossyConn(0.3, 100) + if err != nil { + t.Fatal(err) + } + + server, err := lossyconn.NewLossyConn(0.3, 100) + if err != nil { + t.Fatal(err) + } + testlink(t, client, server, 1, 10, 2, 1) +} + +func TestLossyConn4(t *testing.T) { + t.Log("testing loss rate 10%, rtt 200ms") + t.Log("testing link with nodelay parameters:1 10 2 0") + client, err := lossyconn.NewLossyConn(0.1, 100) + if err != nil { + t.Fatal(err) + } + + server, err := lossyconn.NewLossyConn(0.1, 100) + if err != nil { + t.Fatal(err) + } + testlink(t, client, server, 1, 10, 2, 0) +} + +func testlink(t *testing.T, client *lossyconn.LossyConn, server *lossyconn.LossyConn, nodelay, interval, resend, nc int) { + t.Log("testing with nodelay parameters:", nodelay, interval, resend, nc) + sess, _ := NewConn2(server.LocalAddr(), nil, 0, 0, client) + listener, _ := ServeConn(nil, 0, 0, server) + echoServer := func(l *Listener) { + for { + conn, err := l.AcceptKCP() + if err != nil { + return + } + go func() { + conn.SetNoDelay(nodelay, interval, resend, nc) + buf := make([]byte, 65536) + for { + n, err := conn.Read(buf) + if err != nil { + return + } + conn.Write(buf[:n]) + } + }() + } + } + + echoTester := func(s *UDPSession, raddr net.Addr) { + s.SetNoDelay(nodelay, interval, resend, nc) + buf := make([]byte, 64) + var rtt time.Duration + for i := 0; i < repeat; i++ { + start := time.Now() + s.Write(buf) + io.ReadFull(s, buf) + rtt += time.Since(start) + } + + t.Log("client:", client) + t.Log("server:", server) + t.Log("avg rtt:", rtt/repeat) + t.Logf("total time: %v for %v round trip:", rtt, repeat) + } + + go echoServer(listener) + echoTester(sess, server.LocalAddr()) +} + 
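+// The snippet below sketches the same client/server wiring outside the test
+// harness (illustration only, not part of the original suite; error handling
+// is omitted and the loss/latency parameters are assumptions):
+//
+//	client, _ := lossyconn.NewLossyConn(0.1, 100)
+//	server, _ := lossyconn.NewLossyConn(0.1, 100)
+//	listener, _ := ServeConn(nil, 0, 0, server) // no encryption, no FEC
+//	go func() {
+//		conn, _ := listener.AcceptKCP()
+//		io.Copy(conn, conn) // echo
+//	}()
+//	sess, _ := NewConn2(server.LocalAddr(), nil, 0, 0, client)
+//	sess.Write([]byte("ping"))
+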
+func BenchmarkFlush(b *testing.B) { + kcp := NewKCP(1, func(buf []byte, size int) {}) + kcp.snd_buf = make([]segment, 1024) + for k := range kcp.snd_buf { + kcp.snd_buf[k].xmit = 1 + kcp.snd_buf[k].resendts = currentMs() + 10000 + } + b.ResetTimer() + b.ReportAllocs() + var mu sync.Mutex + for i := 0; i < b.N; i++ { + mu.Lock() + kcp.flush(false) + mu.Unlock() + } +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/readloop.go b/vendor/github.com/xtaci/kcp-go/v5/readloop.go new file mode 100644 index 00000000..697395ab --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/readloop.go @@ -0,0 +1,39 @@ +package kcp + +import ( + "sync/atomic" + + "github.com/pkg/errors" +) + +func (s *UDPSession) defaultReadLoop() { + buf := make([]byte, mtuLimit) + var src string + for { + if n, addr, err := s.conn.ReadFrom(buf); err == nil { + // make sure the packet is from the same source + if src == "" { // set source address + src = addr.String() + } else if addr.String() != src { + atomic.AddUint64(&DefaultSnmp.InErrs, 1) + continue + } + s.packetInput(buf[:n]) + } else { + s.notifyReadError(errors.WithStack(err)) + return + } + } +} + +func (l *Listener) defaultMonitor() { + buf := make([]byte, mtuLimit) + for { + if n, from, err := l.conn.ReadFrom(buf); err == nil { + l.packetInput(buf[:n], from) + } else { + l.notifyReadError(errors.WithStack(err)) + return + } + } +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/readloop_generic.go b/vendor/github.com/xtaci/kcp-go/v5/readloop_generic.go new file mode 100644 index 00000000..5dbe4f44 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/readloop_generic.go @@ -0,0 +1,11 @@ +// +build !linux + +package kcp + +func (s *UDPSession) readLoop() { + s.defaultReadLoop() +} + +func (l *Listener) monitor() { + l.defaultMonitor() +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/readloop_linux.go b/vendor/github.com/xtaci/kcp-go/v5/readloop_linux.go new file mode 100644 index 00000000..be194afb --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/readloop_linux.go @@ -0,0 +1,111 @@ +// +build linux + +package kcp + +import ( + "net" + "os" + "sync/atomic" + + "github.com/pkg/errors" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// the read loop for a client session +func (s *UDPSession) readLoop() { + // default version + if s.xconn == nil { + s.defaultReadLoop() + return + } + + // x/net version + var src string + msgs := make([]ipv4.Message, batchSize) + for k := range msgs { + msgs[k].Buffers = [][]byte{make([]byte, mtuLimit)} + } + + for { + if count, err := s.xconn.ReadBatch(msgs, 0); err == nil { + for i := 0; i < count; i++ { + msg := &msgs[i] + // make sure the packet is from the same source + if src == "" { // set source address if nil + src = msg.Addr.String() + } else if msg.Addr.String() != src { + atomic.AddUint64(&DefaultSnmp.InErrs, 1) + continue + } + + // source and size has validated + s.packetInput(msg.Buffers[0][:msg.N]) + } + } else { + // compatibility issue: + // for linux kernel<=2.6.32, support for sendmmsg is not available + // an error of type os.SyscallError will be returned + if operr, ok := err.(*net.OpError); ok { + if se, ok := operr.Err.(*os.SyscallError); ok { + if se.Syscall == "recvmmsg" { + s.defaultReadLoop() + return + } + } + } + s.notifyReadError(errors.WithStack(err)) + return + } + } +} + +// monitor incoming data for all connections of server +func (l *Listener) monitor() { + var xconn batchConn + if _, ok := l.conn.(*net.UDPConn); ok { + addr, err := net.ResolveUDPAddr("udp", 
l.conn.LocalAddr().String())
+		if err == nil {
+			if addr.IP.To4() != nil {
+				xconn = ipv4.NewPacketConn(l.conn)
+			} else {
+				xconn = ipv6.NewPacketConn(l.conn)
+			}
+		}
+	}
+
+	// default version
+	if xconn == nil {
+		l.defaultMonitor()
+		return
+	}
+
+	// x/net version
+	msgs := make([]ipv4.Message, batchSize)
+	for k := range msgs {
+		msgs[k].Buffers = [][]byte{make([]byte, mtuLimit)}
+	}
+
+	for {
+		if count, err := xconn.ReadBatch(msgs, 0); err == nil {
+			for i := 0; i < count; i++ {
+				msg := &msgs[i]
+				l.packetInput(msg.Buffers[0][:msg.N], msg.Addr)
+			}
+		} else {
+			// compatibility issue:
+			// for linux kernel<=2.6.32, support for sendmmsg is not available
+			// an error of type os.SyscallError will be returned
+			if operr, ok := err.(*net.OpError); ok {
+				if se, ok := operr.Err.(*os.SyscallError); ok {
+					if se.Syscall == "recvmmsg" {
+						l.defaultMonitor()
+						return
+					}
+				}
+			}
+			l.notifyReadError(errors.WithStack(err))
+			return
+		}
+	}
+}
diff --git a/vendor/github.com/xtaci/kcp-go/v5/sess.go b/vendor/github.com/xtaci/kcp-go/v5/sess.go
new file mode 100644
index 00000000..2dedd745
--- /dev/null
+++ b/vendor/github.com/xtaci/kcp-go/v5/sess.go
@@ -0,0 +1,1075 @@
+// Package kcp-go is a Reliable-UDP library for golang.
+//
+// This library intends to provide smooth, resilient, ordered,
+// error-checked and anonymous delivery of streams over UDP packets.
+//
+// The interfaces of this package aim to be compatible with
+// net.Conn in the standard library, while offering powerful features for advanced users.
+package kcp
+
+import (
+	"crypto/rand"
+	"encoding/binary"
+	"hash/crc32"
+	"io"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/pkg/errors"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
+
+const (
+	// 16-bytes nonce for each packet
+	nonceSize = 16
+
+	// 4-bytes packet checksum
+	crcSize = 4
+
+	// overall crypto header size
+	cryptHeaderSize = nonceSize + crcSize
+
+	// maximum packet size
+	mtuLimit = 1500
+
+	// accept backlog
+	acceptBacklog = 128
+)
+
+var (
+	errInvalidOperation = errors.New("invalid operation")
+	errTimeout          = errors.New("timeout")
+)
+
+var (
+	// a system-wide packet buffer shared among sending, receiving and FEC;
+	// to mitigate high-frequency memory allocation for packets, bytes from xmitBuf
+	// are aligned to 64bit
+	xmitBuf sync.Pool
+)
+
+func init() {
+	xmitBuf.New = func() interface{} {
+		return make([]byte, mtuLimit)
+	}
+}
+
+type (
+	// UDPSession defines a KCP session implemented by UDP
+	UDPSession struct {
+		conn    net.PacketConn // the underlying packet connection
+		ownConn bool           // true if we created conn internally, false if provided by caller
+		kcp     *KCP           // KCP ARQ protocol
+		l       *Listener      // pointing to the Listener object if it's been accepted by a Listener
+		block   BlockCrypt     // block encryption object
+
+		// kcp receiving is based on packets
+		// recvbuf turns packets into stream
+		recvbuf []byte
+		bufptr  []byte
+
+		// FEC codec
+		fecDecoder *fecDecoder
+		fecEncoder *fecEncoder
+
+		// settings
+		remote     net.Addr  // remote peer address
+		rd         time.Time // read deadline
+		wd         time.Time // write deadline
+		headerSize int       // the header size additional to a KCP frame
+		ackNoDelay bool      // send ack immediately for each incoming packet(testing purpose)
+		writeDelay bool      // delay kcp.flush() for Write() for bulk transfer
+		dup        int       // duplicate udp packets(testing purpose)
+
+		// notifications
+		die         chan struct{} // notify current session has Closed
+		dieOnce     sync.Once
+		chReadEvent chan struct{} // notify Read() can be called without blocking
+ chWriteEvent chan struct{} // notify Write() can be called without blocking + + // socket error handling + socketReadError atomic.Value + socketWriteError atomic.Value + chSocketReadError chan struct{} + chSocketWriteError chan struct{} + socketReadErrorOnce sync.Once + socketWriteErrorOnce sync.Once + + // nonce generator + nonce Entropy + + // packets waiting to be sent on wire + txqueue []ipv4.Message + xconn batchConn // for x/net + xconnWriteError error + + mu sync.Mutex + } + + setReadBuffer interface { + SetReadBuffer(bytes int) error + } + + setWriteBuffer interface { + SetWriteBuffer(bytes int) error + } + + setDSCP interface { + SetDSCP(int) error + } +) + +// newUDPSession create a new udp session for client or server +func newUDPSession(conv uint32, dataShards, parityShards int, l *Listener, conn net.PacketConn, ownConn bool, remote net.Addr, block BlockCrypt) *UDPSession { + sess := new(UDPSession) + sess.die = make(chan struct{}) + sess.nonce = new(nonceAES128) + sess.nonce.Init() + sess.chReadEvent = make(chan struct{}, 1) + sess.chWriteEvent = make(chan struct{}, 1) + sess.chSocketReadError = make(chan struct{}) + sess.chSocketWriteError = make(chan struct{}) + sess.remote = remote + sess.conn = conn + sess.ownConn = ownConn + sess.l = l + sess.block = block + sess.recvbuf = make([]byte, mtuLimit) + + // cast to writebatch conn + if _, ok := conn.(*net.UDPConn); ok { + addr, err := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) + if err == nil { + if addr.IP.To4() != nil { + sess.xconn = ipv4.NewPacketConn(conn) + } else { + sess.xconn = ipv6.NewPacketConn(conn) + } + } + } + + // FEC codec initialization + sess.fecDecoder = newFECDecoder(dataShards, parityShards) + if sess.block != nil { + sess.fecEncoder = newFECEncoder(dataShards, parityShards, cryptHeaderSize) + } else { + sess.fecEncoder = newFECEncoder(dataShards, parityShards, 0) + } + + // calculate additional header size introduced by FEC and encryption + if sess.block != nil { + sess.headerSize += cryptHeaderSize + } + if sess.fecEncoder != nil { + sess.headerSize += fecHeaderSizePlus2 + } + + sess.kcp = NewKCP(conv, func(buf []byte, size int) { + if size >= IKCP_OVERHEAD+sess.headerSize { + sess.output(buf[:size]) + } + }) + sess.kcp.ReserveBytes(sess.headerSize) + + if sess.l == nil { // it's a client connection + go sess.readLoop() + atomic.AddUint64(&DefaultSnmp.ActiveOpens, 1) + } else { + atomic.AddUint64(&DefaultSnmp.PassiveOpens, 1) + } + + // start per-session updater + SystemTimedSched.Put(sess.update, time.Now()) + + currestab := atomic.AddUint64(&DefaultSnmp.CurrEstab, 1) + maxconn := atomic.LoadUint64(&DefaultSnmp.MaxConn) + if currestab > maxconn { + atomic.CompareAndSwapUint64(&DefaultSnmp.MaxConn, maxconn, currestab) + } + + return sess +} + +// Read implements net.Conn +func (s *UDPSession) Read(b []byte) (n int, err error) { + for { + s.mu.Lock() + if len(s.bufptr) > 0 { // copy from buffer into b + n = copy(b, s.bufptr) + s.bufptr = s.bufptr[n:] + s.mu.Unlock() + atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(n)) + return n, nil + } + + if size := s.kcp.PeekSize(); size > 0 { // peek data size from kcp + if len(b) >= size { // receive data into 'b' directly + s.kcp.Recv(b) + s.mu.Unlock() + atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(size)) + return size, nil + } + + // if necessary resize the stream buffer to guarantee a sufficient buffer space + if cap(s.recvbuf) < size { + s.recvbuf = make([]byte, size) + } + + // resize the length of recvbuf to correspond to data size + 
s.recvbuf = s.recvbuf[:size] + s.kcp.Recv(s.recvbuf) + n = copy(b, s.recvbuf) // copy to 'b' + s.bufptr = s.recvbuf[n:] // pointer update + s.mu.Unlock() + atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(n)) + return n, nil + } + + // deadline for current reading operation + var timeout *time.Timer + var c <-chan time.Time + if !s.rd.IsZero() { + if time.Now().After(s.rd) { + s.mu.Unlock() + return 0, errors.WithStack(errTimeout) + } + + delay := time.Until(s.rd) + timeout = time.NewTimer(delay) + c = timeout.C + } + s.mu.Unlock() + + // wait for read event or timeout or error + select { + case <-s.chReadEvent: + if timeout != nil { + timeout.Stop() + } + case <-c: + return 0, errors.WithStack(errTimeout) + case <-s.chSocketReadError: + return 0, s.socketReadError.Load().(error) + case <-s.die: + return 0, errors.WithStack(io.ErrClosedPipe) + } + } +} + +// Write implements net.Conn +func (s *UDPSession) Write(b []byte) (n int, err error) { return s.WriteBuffers([][]byte{b}) } + +// WriteBuffers write a vector of byte slices to the underlying connection +func (s *UDPSession) WriteBuffers(v [][]byte) (n int, err error) { + for { + select { + case <-s.chSocketWriteError: + return 0, s.socketWriteError.Load().(error) + case <-s.die: + return 0, errors.WithStack(io.ErrClosedPipe) + default: + } + + s.mu.Lock() + + // make sure write do not overflow the max sliding window on both side + waitsnd := s.kcp.WaitSnd() + if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) { + for _, b := range v { + n += len(b) + for { + if len(b) <= int(s.kcp.mss) { + s.kcp.Send(b) + break + } else { + s.kcp.Send(b[:s.kcp.mss]) + b = b[s.kcp.mss:] + } + } + } + + waitsnd = s.kcp.WaitSnd() + if waitsnd >= int(s.kcp.snd_wnd) || waitsnd >= int(s.kcp.rmt_wnd) || !s.writeDelay { + s.kcp.flush(false) + s.uncork() + } + s.mu.Unlock() + atomic.AddUint64(&DefaultSnmp.BytesSent, uint64(n)) + return n, nil + } + + var timeout *time.Timer + var c <-chan time.Time + if !s.wd.IsZero() { + if time.Now().After(s.wd) { + s.mu.Unlock() + return 0, errors.WithStack(errTimeout) + } + delay := time.Until(s.wd) + timeout = time.NewTimer(delay) + c = timeout.C + } + s.mu.Unlock() + + select { + case <-s.chWriteEvent: + if timeout != nil { + timeout.Stop() + } + case <-c: + return 0, errors.WithStack(errTimeout) + case <-s.chSocketWriteError: + return 0, s.socketWriteError.Load().(error) + case <-s.die: + return 0, errors.WithStack(io.ErrClosedPipe) + } + } +} + +// uncork sends data in txqueue if there is any +func (s *UDPSession) uncork() { + if len(s.txqueue) > 0 { + s.tx(s.txqueue) + // recycle + for k := range s.txqueue { + xmitBuf.Put(s.txqueue[k].Buffers[0]) + s.txqueue[k].Buffers = nil + } + s.txqueue = s.txqueue[:0] + } +} + +// Close closes the connection. +func (s *UDPSession) Close() error { + var once bool + s.dieOnce.Do(func() { + close(s.die) + once = true + }) + + if once { + atomic.AddUint64(&DefaultSnmp.CurrEstab, ^uint64(0)) + + // try best to send all queued messages + s.mu.Lock() + s.kcp.flush(false) + s.uncork() + // release pending segments + s.kcp.ReleaseTX() + if s.fecDecoder != nil { + s.fecDecoder.release() + } + s.mu.Unlock() + + if s.l != nil { // belongs to listener + s.l.closeSession(s.remote) + return nil + } else if s.ownConn { // client socket close + return s.conn.Close() + } else { + return nil + } + } else { + return errors.WithStack(io.ErrClosedPipe) + } +} + +// LocalAddr returns the local network address. 
The Addr returned is shared by all invocations of LocalAddr, so do not modify it. +func (s *UDPSession) LocalAddr() net.Addr { return s.conn.LocalAddr() } + +// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. +func (s *UDPSession) RemoteAddr() net.Addr { return s.remote } + +// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. +func (s *UDPSession) SetDeadline(t time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + s.rd = t + s.wd = t + s.notifyReadEvent() + s.notifyWriteEvent() + return nil +} + +// SetReadDeadline implements the Conn SetReadDeadline method. +func (s *UDPSession) SetReadDeadline(t time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + s.rd = t + s.notifyReadEvent() + return nil +} + +// SetWriteDeadline implements the Conn SetWriteDeadline method. +func (s *UDPSession) SetWriteDeadline(t time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + s.wd = t + s.notifyWriteEvent() + return nil +} + +// SetWriteDelay delays write for bulk transfer until the next update interval +func (s *UDPSession) SetWriteDelay(delay bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.writeDelay = delay +} + +// SetWindowSize set maximum window size +func (s *UDPSession) SetWindowSize(sndwnd, rcvwnd int) { + s.mu.Lock() + defer s.mu.Unlock() + s.kcp.WndSize(sndwnd, rcvwnd) +} + +// SetMtu sets the maximum transmission unit(not including UDP header) +func (s *UDPSession) SetMtu(mtu int) bool { + if mtu > mtuLimit { + return false + } + + s.mu.Lock() + defer s.mu.Unlock() + s.kcp.SetMtu(mtu) + return true +} + +// SetStreamMode toggles the stream mode on/off +func (s *UDPSession) SetStreamMode(enable bool) { + s.mu.Lock() + defer s.mu.Unlock() + if enable { + s.kcp.stream = 1 + } else { + s.kcp.stream = 0 + } +} + +// SetACKNoDelay changes ack flush option, set true to flush ack immediately, +func (s *UDPSession) SetACKNoDelay(nodelay bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.ackNoDelay = nodelay +} + +// (deprecated) +// +// SetDUP duplicates udp packets for kcp output. +func (s *UDPSession) SetDUP(dup int) { + s.mu.Lock() + defer s.mu.Unlock() + s.dup = dup +} + +// SetNoDelay calls nodelay() of kcp +// https://github.com/skywind3000/kcp/blob/master/README.en.md#protocol-configuration +func (s *UDPSession) SetNoDelay(nodelay, interval, resend, nc int) { + s.mu.Lock() + defer s.mu.Unlock() + s.kcp.NoDelay(nodelay, interval, resend, nc) +} + +// SetDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header. +// +// if the underlying connection has implemented `func SetDSCP(int) error`, SetDSCP() will invoke +// this function instead. +// +// It has no effect if it's accepted from Listener. 
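+//
+// For example, sess.SetDSCP(46) requests the EF (Expedited Forwarding) code
+// point for outgoing packets, provided the underlying socket supports it.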
+func (s *UDPSession) SetDSCP(dscp int) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.l != nil { + return errInvalidOperation + } + + // interface enabled + if ts, ok := s.conn.(setDSCP); ok { + return ts.SetDSCP(dscp) + } + + if nc, ok := s.conn.(net.Conn); ok { + var succeed bool + if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err == nil { + succeed = true + } + if err := ipv6.NewConn(nc).SetTrafficClass(dscp); err == nil { + succeed = true + } + + if succeed { + return nil + } + } + return errInvalidOperation +} + +// SetReadBuffer sets the socket read buffer, no effect if it's accepted from Listener +func (s *UDPSession) SetReadBuffer(bytes int) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.l == nil { + if nc, ok := s.conn.(setReadBuffer); ok { + return nc.SetReadBuffer(bytes) + } + } + return errInvalidOperation +} + +// SetWriteBuffer sets the socket write buffer, no effect if it's accepted from Listener +func (s *UDPSession) SetWriteBuffer(bytes int) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.l == nil { + if nc, ok := s.conn.(setWriteBuffer); ok { + return nc.SetWriteBuffer(bytes) + } + } + return errInvalidOperation +} + +// post-processing for sending a packet from kcp core +// steps: +// 1. FEC packet generation +// 2. CRC32 integrity +// 3. Encryption +// 4. TxQueue +func (s *UDPSession) output(buf []byte) { + var ecc [][]byte + + // 1. FEC encoding + if s.fecEncoder != nil { + ecc = s.fecEncoder.encode(buf) + } + + // 2&3. crc32 & encryption + if s.block != nil { + s.nonce.Fill(buf[:nonceSize]) + checksum := crc32.ChecksumIEEE(buf[cryptHeaderSize:]) + binary.LittleEndian.PutUint32(buf[nonceSize:], checksum) + s.block.Encrypt(buf, buf) + + for k := range ecc { + s.nonce.Fill(ecc[k][:nonceSize]) + checksum := crc32.ChecksumIEEE(ecc[k][cryptHeaderSize:]) + binary.LittleEndian.PutUint32(ecc[k][nonceSize:], checksum) + s.block.Encrypt(ecc[k], ecc[k]) + } + } + + // 4. 
TxQueue
+    var msg ipv4.Message
+    for i := 0; i < s.dup+1; i++ {
+        bts := xmitBuf.Get().([]byte)[:len(buf)]
+        copy(bts, buf)
+        msg.Buffers = [][]byte{bts}
+        msg.Addr = s.remote
+        s.txqueue = append(s.txqueue, msg)
+    }
+
+    for k := range ecc {
+        bts := xmitBuf.Get().([]byte)[:len(ecc[k])]
+        copy(bts, ecc[k])
+        msg.Buffers = [][]byte{bts}
+        msg.Addr = s.remote
+        s.txqueue = append(s.txqueue, msg)
+    }
+}
+
+// update flushes the session's kcp state machine and reschedules itself to drive the protocol
+func (s *UDPSession) update() {
+    select {
+    case <-s.die:
+    default:
+        s.mu.Lock()
+        interval := s.kcp.flush(false)
+        waitsnd := s.kcp.WaitSnd()
+        if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) {
+            s.notifyWriteEvent()
+        }
+        s.uncork()
+        s.mu.Unlock()
+        // self-synchronized timed scheduling
+        SystemTimedSched.Put(s.update, time.Now().Add(time.Duration(interval)*time.Millisecond))
+    }
+}
+
+// GetConv gets the conversation id of a session
+func (s *UDPSession) GetConv() uint32 { return s.kcp.conv }
+
+// GetRTO gets the current rto of the session
+func (s *UDPSession) GetRTO() uint32 {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+    return s.kcp.rx_rto
+}
+
+// GetSRTT gets the current srtt of the session
+func (s *UDPSession) GetSRTT() int32 {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+    return s.kcp.rx_srtt
+}
+
+// GetSRTTVar gets the current rtt variance of the session
+func (s *UDPSession) GetSRTTVar() int32 {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+    return s.kcp.rx_rttvar
+}
+
+func (s *UDPSession) notifyReadEvent() {
+    select {
+    case s.chReadEvent <- struct{}{}:
+    default:
+    }
+}
+
+func (s *UDPSession) notifyWriteEvent() {
+    select {
+    case s.chWriteEvent <- struct{}{}:
+    default:
+    }
+}
+
+func (s *UDPSession) notifyReadError(err error) {
+    s.socketReadErrorOnce.Do(func() {
+        s.socketReadError.Store(err)
+        close(s.chSocketReadError)
+    })
+}
+
+func (s *UDPSession) notifyWriteError(err error) {
+    s.socketWriteErrorOnce.Do(func() {
+        s.socketWriteError.Store(err)
+        close(s.chSocketWriteError)
+    })
+}
+
+// packet input stage
+func (s *UDPSession) packetInput(data []byte) {
+    decrypted := false
+    if s.block != nil && len(data) >= cryptHeaderSize {
+        s.block.Decrypt(data, data)
+        data = data[nonceSize:]
+        checksum := crc32.ChecksumIEEE(data[crcSize:])
+        if checksum == binary.LittleEndian.Uint32(data) {
+            data = data[crcSize:]
+            decrypted = true
+        } else {
+            atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
+        }
+    } else if s.block == nil {
+        decrypted = true
+    }
+
+    if decrypted && len(data) >= IKCP_OVERHEAD {
+        s.kcpInput(data)
+    }
+}
+
+func (s *UDPSession) kcpInput(data []byte) {
+    var kcpInErrors, fecErrs, fecRecovered, fecParityShards uint64
+
+    fecFlag := binary.LittleEndian.Uint16(data[4:])
+    if fecFlag == typeData || fecFlag == typeParity { // 16bit kcp cmd [81-84] and frg [0-255] will not overlap with FEC type 0x00f1 0x00f2
+        if len(data) >= fecHeaderSizePlus2 {
+            f := fecPacket(data)
+            if f.flag() == typeParity {
+                fecParityShards++
+            }
+
+            // lock
+            s.mu.Lock()
+            // if fecDecoder is not initialized, create one with default parameters
+            if s.fecDecoder == nil {
+                s.fecDecoder = newFECDecoder(1, 1)
+            }
+            recovers := s.fecDecoder.decode(f)
+            if f.flag() == typeData {
+                if ret := s.kcp.Input(data[fecHeaderSizePlus2:], true, s.ackNoDelay); ret != 0 {
+                    kcpInErrors++
+                }
+            }
+
+            for _, r := range recovers {
+                if len(r) >= 2 { // must be at least 2 bytes
+                    sz := binary.LittleEndian.Uint16(r)
+                    if int(sz) <= len(r) && sz >= 2 {
+                        if ret := s.kcp.Input(r[2:sz], false, s.ackNoDelay); ret == 0 {
+                            fecRecovered++
+                        } else {
+                            kcpInErrors++
+                        }
+                    } else {
+                        fecErrs++
+                    }
+
} else { + fecErrs++ + } + // recycle the recovers + xmitBuf.Put(r) + } + + // to notify the readers to receive the data + if n := s.kcp.PeekSize(); n > 0 { + s.notifyReadEvent() + } + // to notify the writers + waitsnd := s.kcp.WaitSnd() + if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) { + s.notifyWriteEvent() + } + + s.uncork() + s.mu.Unlock() + } else { + atomic.AddUint64(&DefaultSnmp.InErrs, 1) + } + } else { + s.mu.Lock() + if ret := s.kcp.Input(data, true, s.ackNoDelay); ret != 0 { + kcpInErrors++ + } + if n := s.kcp.PeekSize(); n > 0 { + s.notifyReadEvent() + } + waitsnd := s.kcp.WaitSnd() + if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) { + s.notifyWriteEvent() + } + s.uncork() + s.mu.Unlock() + } + + atomic.AddUint64(&DefaultSnmp.InPkts, 1) + atomic.AddUint64(&DefaultSnmp.InBytes, uint64(len(data))) + if fecParityShards > 0 { + atomic.AddUint64(&DefaultSnmp.FECParityShards, fecParityShards) + } + if kcpInErrors > 0 { + atomic.AddUint64(&DefaultSnmp.KCPInErrors, kcpInErrors) + } + if fecErrs > 0 { + atomic.AddUint64(&DefaultSnmp.FECErrs, fecErrs) + } + if fecRecovered > 0 { + atomic.AddUint64(&DefaultSnmp.FECRecovered, fecRecovered) + } + +} + +type ( + // Listener defines a server which will be waiting to accept incoming connections + Listener struct { + block BlockCrypt // block encryption + dataShards int // FEC data shard + parityShards int // FEC parity shard + conn net.PacketConn // the underlying packet connection + ownConn bool // true if we created conn internally, false if provided by caller + + sessions map[string]*UDPSession // all sessions accepted by this Listener + sessionLock sync.RWMutex + chAccepts chan *UDPSession // Listen() backlog + chSessionClosed chan net.Addr // session close queue + + die chan struct{} // notify the listener has closed + dieOnce sync.Once + + // socket error handling + socketReadError atomic.Value + chSocketReadError chan struct{} + socketReadErrorOnce sync.Once + + rd atomic.Value // read deadline for Accept() + } +) + +// packet input stage +func (l *Listener) packetInput(data []byte, addr net.Addr) { + decrypted := false + if l.block != nil && len(data) >= cryptHeaderSize { + l.block.Decrypt(data, data) + data = data[nonceSize:] + checksum := crc32.ChecksumIEEE(data[crcSize:]) + if checksum == binary.LittleEndian.Uint32(data) { + data = data[crcSize:] + decrypted = true + } else { + atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1) + } + } else if l.block == nil { + decrypted = true + } + + if decrypted && len(data) >= IKCP_OVERHEAD { + l.sessionLock.RLock() + s, ok := l.sessions[addr.String()] + l.sessionLock.RUnlock() + + var conv, sn uint32 + convRecovered := false + fecFlag := binary.LittleEndian.Uint16(data[4:]) + if fecFlag == typeData || fecFlag == typeParity { // 16bit kcp cmd [81-84] and frg [0-255] will not overlap with FEC type 0x00f1 0x00f2 + // packet with FEC + if fecFlag == typeData && len(data) >= fecHeaderSizePlus2+IKCP_OVERHEAD { + conv = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2:]) + sn = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2+IKCP_SN_OFFSET:]) + convRecovered = true + } + } else { + // packet without FEC + conv = binary.LittleEndian.Uint32(data) + sn = binary.LittleEndian.Uint32(data[IKCP_SN_OFFSET:]) + convRecovered = true + } + + if ok { // existing connection + if !convRecovered || conv == s.kcp.conv { // parity data or valid conversation + s.kcpInput(data) + } else if sn == 0 { // should replace current connection + s.Close() + s = nil + } + } + + if s == nil && 
convRecovered { // new session + if len(l.chAccepts) < cap(l.chAccepts) { // do not let the new sessions overwhelm accept queue + s := newUDPSession(conv, l.dataShards, l.parityShards, l, l.conn, false, addr, l.block) + s.kcpInput(data) + l.sessionLock.Lock() + l.sessions[addr.String()] = s + l.sessionLock.Unlock() + l.chAccepts <- s + } + } + } +} + +func (l *Listener) notifyReadError(err error) { + l.socketReadErrorOnce.Do(func() { + l.socketReadError.Store(err) + close(l.chSocketReadError) + + // propagate read error to all sessions + l.sessionLock.RLock() + for _, s := range l.sessions { + s.notifyReadError(err) + } + l.sessionLock.RUnlock() + }) +} + +// SetReadBuffer sets the socket read buffer for the Listener +func (l *Listener) SetReadBuffer(bytes int) error { + if nc, ok := l.conn.(setReadBuffer); ok { + return nc.SetReadBuffer(bytes) + } + return errInvalidOperation +} + +// SetWriteBuffer sets the socket write buffer for the Listener +func (l *Listener) SetWriteBuffer(bytes int) error { + if nc, ok := l.conn.(setWriteBuffer); ok { + return nc.SetWriteBuffer(bytes) + } + return errInvalidOperation +} + +// SetDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header. +// +// if the underlying connection has implemented `func SetDSCP(int) error`, SetDSCP() will invoke +// this function instead. +func (l *Listener) SetDSCP(dscp int) error { + // interface enabled + if ts, ok := l.conn.(setDSCP); ok { + return ts.SetDSCP(dscp) + } + + if nc, ok := l.conn.(net.Conn); ok { + var succeed bool + if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err == nil { + succeed = true + } + if err := ipv6.NewConn(nc).SetTrafficClass(dscp); err == nil { + succeed = true + } + + if succeed { + return nil + } + } + return errInvalidOperation +} + +// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn. +func (l *Listener) Accept() (net.Conn, error) { + return l.AcceptKCP() +} + +// AcceptKCP accepts a KCP connection +func (l *Listener) AcceptKCP() (*UDPSession, error) { + var timeout <-chan time.Time + if tdeadline, ok := l.rd.Load().(time.Time); ok && !tdeadline.IsZero() { + timeout = time.After(time.Until(tdeadline)) + } + + select { + case <-timeout: + return nil, errors.WithStack(errTimeout) + case c := <-l.chAccepts: + return c, nil + case <-l.chSocketReadError: + return nil, l.socketReadError.Load().(error) + case <-l.die: + return nil, errors.WithStack(io.ErrClosedPipe) + } +} + +// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. +func (l *Listener) SetDeadline(t time.Time) error { + l.SetReadDeadline(t) + l.SetWriteDeadline(t) + return nil +} + +// SetReadDeadline implements the Conn SetReadDeadline method. +func (l *Listener) SetReadDeadline(t time.Time) error { + l.rd.Store(t) + return nil +} + +// SetWriteDeadline implements the Conn SetWriteDeadline method. 
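+// A Listener has no write path of its own, so this is a stub that always
+// reports errInvalidOperation.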
+func (l *Listener) SetWriteDeadline(t time.Time) error { return errInvalidOperation }
+
+// Close stops listening on the UDP address, and closes the socket
+func (l *Listener) Close() error {
+    var once bool
+    l.dieOnce.Do(func() {
+        close(l.die)
+        once = true
+    })
+
+    var err error
+    if once {
+        if l.ownConn {
+            err = l.conn.Close()
+        }
+    } else {
+        err = errors.WithStack(io.ErrClosedPipe)
+    }
+    return err
+}
+
+// closeSession notifies the listener that a session has closed
+func (l *Listener) closeSession(remote net.Addr) (ret bool) {
+    l.sessionLock.Lock()
+    defer l.sessionLock.Unlock()
+    if _, ok := l.sessions[remote.String()]; ok {
+        delete(l.sessions, remote.String())
+        return true
+    }
+    return false
+}
+
+// Addr returns the listener's network address. The Addr returned is shared by all invocations of Addr, so do not modify it.
+func (l *Listener) Addr() net.Addr { return l.conn.LocalAddr() }
+
+// Listen listens for incoming KCP packets addressed to the local address laddr on the network "udp".
+func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr, nil, 0, 0) }
+
+// ListenWithOptions listens for incoming KCP packets addressed to the local address laddr on the network "udp" with packet encryption.
+//
+// 'block' is the block encryption algorithm to encrypt packets.
+//
+// 'dataShards', 'parityShards' specify how many parity packets will be generated following the data packets.
+//
+// Check https://github.com/klauspost/reedsolomon for details
+func ListenWithOptions(laddr string, block BlockCrypt, dataShards, parityShards int) (*Listener, error) {
+    udpaddr, err := net.ResolveUDPAddr("udp", laddr)
+    if err != nil {
+        return nil, errors.WithStack(err)
+    }
+    conn, err := net.ListenUDP("udp", udpaddr)
+    if err != nil {
+        return nil, errors.WithStack(err)
+    }
+
+    return serveConn(block, dataShards, parityShards, conn, true)
+}
+
+// ServeConn serves KCP protocol for a single packet connection.
+func ServeConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*Listener, error) {
+    return serveConn(block, dataShards, parityShards, conn, false)
+}
+
+func serveConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketConn, ownConn bool) (*Listener, error) {
+    l := new(Listener)
+    l.conn = conn
+    l.ownConn = ownConn
+    l.sessions = make(map[string]*UDPSession)
+    l.chAccepts = make(chan *UDPSession, acceptBacklog)
+    l.chSessionClosed = make(chan net.Addr)
+    l.die = make(chan struct{})
+    l.dataShards = dataShards
+    l.parityShards = parityShards
+    l.block = block
+    l.chSocketReadError = make(chan struct{})
+    go l.monitor()
+    return l, nil
+}
+
+// Dial connects to the remote address "raddr" on the network "udp" without encryption and FEC.
+func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr, nil, 0, 0) }
+
+// DialWithOptions connects to the remote address "raddr" on the network "udp" with packet encryption.
+//
+// 'block' is the block encryption algorithm to encrypt packets.
+//
+// 'dataShards', 'parityShards' specify how many parity packets will be generated following the data packets.
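+//
+// Note: with Reed-Solomon coding, every group of 'dataShards' data packets is
+// followed by 'parityShards' parity packets, and any 'dataShards' packets of
+// the group are sufficient to reconstruct the rest.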
+// +// Check https://github.com/klauspost/reedsolomon for details +func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards int) (*UDPSession, error) { + // network type detection + udpaddr, err := net.ResolveUDPAddr("udp", raddr) + if err != nil { + return nil, errors.WithStack(err) + } + network := "udp4" + if udpaddr.IP.To4() == nil { + network = "udp" + } + + conn, err := net.ListenUDP(network, nil) + if err != nil { + return nil, errors.WithStack(err) + } + + var convid uint32 + binary.Read(rand.Reader, binary.LittleEndian, &convid) + return newUDPSession(convid, dataShards, parityShards, nil, conn, true, udpaddr, block), nil +} + +// NewConn3 establishes a session and talks KCP protocol over a packet connection. +func NewConn3(convid uint32, raddr net.Addr, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) { + return newUDPSession(convid, dataShards, parityShards, nil, conn, false, raddr, block), nil +} + +// NewConn2 establishes a session and talks KCP protocol over a packet connection. +func NewConn2(raddr net.Addr, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) { + var convid uint32 + binary.Read(rand.Reader, binary.LittleEndian, &convid) + return NewConn3(convid, raddr, block, dataShards, parityShards, conn) +} + +// NewConn establishes a session and talks KCP protocol over a packet connection. +func NewConn(raddr string, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) { + udpaddr, err := net.ResolveUDPAddr("udp", raddr) + if err != nil { + return nil, errors.WithStack(err) + } + return NewConn2(udpaddr, block, dataShards, parityShards, conn) +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/sess_test.go b/vendor/github.com/xtaci/kcp-go/v5/sess_test.go new file mode 100644 index 00000000..fbe3ad1a --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/sess_test.go @@ -0,0 +1,702 @@ +package kcp + +import ( + "crypto/sha1" + "fmt" + "io" + "log" + "net" + "net/http" + _ "net/http/pprof" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/crypto/pbkdf2" +) + +var baseport = uint32(10000) +var key = []byte("testkey") +var pass = pbkdf2.Key(key, []byte("testsalt"), 4096, 32, sha1.New) + +func init() { + go func() { + log.Println(http.ListenAndServe("0.0.0.0:6060", nil)) + }() + + log.Println("beginning tests, encryption:salsa20, fec:10/3") +} + +func dialEcho(port int) (*UDPSession, error) { + //block, _ := NewNoneBlockCrypt(pass) + //block, _ := NewSimpleXORBlockCrypt(pass) + //block, _ := NewTEABlockCrypt(pass[:16]) + //block, _ := NewAESBlockCrypt(pass) + block, _ := NewSalsa20BlockCrypt(pass) + sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3) + if err != nil { + panic(err) + } + + sess.SetStreamMode(true) + sess.SetStreamMode(false) + sess.SetStreamMode(true) + sess.SetWindowSize(1024, 1024) + sess.SetReadBuffer(16 * 1024 * 1024) + sess.SetWriteBuffer(16 * 1024 * 1024) + sess.SetStreamMode(true) + sess.SetNoDelay(1, 10, 2, 1) + sess.SetMtu(1400) + sess.SetMtu(1600) + sess.SetMtu(1400) + sess.SetACKNoDelay(true) + sess.SetACKNoDelay(false) + sess.SetDeadline(time.Now().Add(time.Minute)) + return sess, err +} + +func dialSink(port int) (*UDPSession, error) { + sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 0, 0) + if err != nil { + panic(err) + } + + sess.SetStreamMode(true) + sess.SetWindowSize(1024, 1024) + sess.SetReadBuffer(16 * 1024 * 1024) + sess.SetWriteBuffer(16 * 1024 
* 1024) + sess.SetStreamMode(true) + sess.SetNoDelay(1, 10, 2, 1) + sess.SetMtu(1400) + sess.SetACKNoDelay(false) + sess.SetDeadline(time.Now().Add(time.Minute)) + return sess, err +} + +func dialTinyBufferEcho(port int) (*UDPSession, error) { + //block, _ := NewNoneBlockCrypt(pass) + //block, _ := NewSimpleXORBlockCrypt(pass) + //block, _ := NewTEABlockCrypt(pass[:16]) + //block, _ := NewAESBlockCrypt(pass) + block, _ := NewSalsa20BlockCrypt(pass) + sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3) + if err != nil { + panic(err) + } + return sess, err +} + +////////////////////////// +func listenEcho(port int) (net.Listener, error) { + //block, _ := NewNoneBlockCrypt(pass) + //block, _ := NewSimpleXORBlockCrypt(pass) + //block, _ := NewTEABlockCrypt(pass[:16]) + //block, _ := NewAESBlockCrypt(pass) + block, _ := NewSalsa20BlockCrypt(pass) + return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 0) +} +func listenTinyBufferEcho(port int) (net.Listener, error) { + //block, _ := NewNoneBlockCrypt(pass) + //block, _ := NewSimpleXORBlockCrypt(pass) + //block, _ := NewTEABlockCrypt(pass[:16]) + //block, _ := NewAESBlockCrypt(pass) + block, _ := NewSalsa20BlockCrypt(pass) + return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3) +} + +func listenSink(port int) (net.Listener, error) { + return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 0, 0) +} + +func echoServer(port int) net.Listener { + l, err := listenEcho(port) + if err != nil { + panic(err) + } + + go func() { + kcplistener := l.(*Listener) + kcplistener.SetReadBuffer(4 * 1024 * 1024) + kcplistener.SetWriteBuffer(4 * 1024 * 1024) + kcplistener.SetDSCP(46) + for { + s, err := l.Accept() + if err != nil { + return + } + + // coverage test + s.(*UDPSession).SetReadBuffer(4 * 1024 * 1024) + s.(*UDPSession).SetWriteBuffer(4 * 1024 * 1024) + go handleEcho(s.(*UDPSession)) + } + }() + + return l +} + +func sinkServer(port int) net.Listener { + l, err := listenSink(port) + if err != nil { + panic(err) + } + + go func() { + kcplistener := l.(*Listener) + kcplistener.SetReadBuffer(4 * 1024 * 1024) + kcplistener.SetWriteBuffer(4 * 1024 * 1024) + kcplistener.SetDSCP(46) + for { + s, err := l.Accept() + if err != nil { + return + } + + go handleSink(s.(*UDPSession)) + } + }() + + return l +} + +func tinyBufferEchoServer(port int) net.Listener { + l, err := listenTinyBufferEcho(port) + if err != nil { + panic(err) + } + + go func() { + for { + s, err := l.Accept() + if err != nil { + return + } + go handleTinyBufferEcho(s.(*UDPSession)) + } + }() + return l +} + +/////////////////////////// + +func handleEcho(conn *UDPSession) { + conn.SetStreamMode(true) + conn.SetWindowSize(4096, 4096) + conn.SetNoDelay(1, 10, 2, 1) + conn.SetDSCP(46) + conn.SetMtu(1400) + conn.SetACKNoDelay(false) + conn.SetReadDeadline(time.Now().Add(time.Hour)) + conn.SetWriteDeadline(time.Now().Add(time.Hour)) + buf := make([]byte, 65536) + for { + n, err := conn.Read(buf) + if err != nil { + return + } + conn.Write(buf[:n]) + } +} + +func handleSink(conn *UDPSession) { + conn.SetStreamMode(true) + conn.SetWindowSize(4096, 4096) + conn.SetNoDelay(1, 10, 2, 1) + conn.SetDSCP(46) + conn.SetMtu(1400) + conn.SetACKNoDelay(false) + conn.SetReadDeadline(time.Now().Add(time.Hour)) + conn.SetWriteDeadline(time.Now().Add(time.Hour)) + buf := make([]byte, 65536) + for { + _, err := conn.Read(buf) + if err != nil { + return + } + } +} + +func handleTinyBufferEcho(conn *UDPSession) { + conn.SetStreamMode(true) + buf 
:= make([]byte, 2) + for { + n, err := conn.Read(buf) + if err != nil { + return + } + conn.Write(buf[:n]) + } +} + +/////////////////////////// + +func TestTimeout(t *testing.T) { + port := int(atomic.AddUint32(&baseport, 1)) + l := echoServer(port) + defer l.Close() + + cli, err := dialEcho(port) + if err != nil { + panic(err) + } + buf := make([]byte, 10) + + //timeout + cli.SetDeadline(time.Now().Add(time.Second)) + <-time.After(2 * time.Second) + n, err := cli.Read(buf) + if n != 0 || err == nil { + t.Fail() + } + cli.Close() +} + +func TestSendRecv(t *testing.T) { + port := int(atomic.AddUint32(&baseport, 1)) + l := echoServer(port) + defer l.Close() + + cli, err := dialEcho(port) + if err != nil { + panic(err) + } + cli.SetWriteDelay(true) + cli.SetDUP(1) + const N = 100 + buf := make([]byte, 10) + for i := 0; i < N; i++ { + msg := fmt.Sprintf("hello%v", i) + cli.Write([]byte(msg)) + if n, err := cli.Read(buf); err == nil { + if string(buf[:n]) != msg { + t.Fail() + } + } else { + panic(err) + } + } + cli.Close() +} + +func TestSendVector(t *testing.T) { + port := int(atomic.AddUint32(&baseport, 1)) + l := echoServer(port) + defer l.Close() + + cli, err := dialEcho(port) + if err != nil { + panic(err) + } + cli.SetWriteDelay(false) + const N = 100 + buf := make([]byte, 20) + v := make([][]byte, 2) + for i := 0; i < N; i++ { + v[0] = []byte(fmt.Sprintf("hello%v", i)) + v[1] = []byte(fmt.Sprintf("world%v", i)) + msg := fmt.Sprintf("hello%vworld%v", i, i) + cli.WriteBuffers(v) + if n, err := cli.Read(buf); err == nil { + if string(buf[:n]) != msg { + t.Error(string(buf[:n]), msg) + } + } else { + panic(err) + } + } + cli.Close() +} + +func TestTinyBufferReceiver(t *testing.T) { + port := int(atomic.AddUint32(&baseport, 1)) + l := tinyBufferEchoServer(port) + defer l.Close() + + cli, err := dialTinyBufferEcho(port) + if err != nil { + panic(err) + } + const N = 100 + snd := byte(0) + fillBuffer := func(buf []byte) { + for i := 0; i < len(buf); i++ { + buf[i] = snd + snd++ + } + } + + rcv := byte(0) + check := func(buf []byte) bool { + for i := 0; i < len(buf); i++ { + if buf[i] != rcv { + return false + } + rcv++ + } + return true + } + sndbuf := make([]byte, 7) + rcvbuf := make([]byte, 7) + for i := 0; i < N; i++ { + fillBuffer(sndbuf) + cli.Write(sndbuf) + if n, err := io.ReadFull(cli, rcvbuf); err == nil { + if !check(rcvbuf[:n]) { + t.Fail() + } + } else { + panic(err) + } + } + cli.Close() +} + +func TestClose(t *testing.T) { + var n int + var err error + + port := int(atomic.AddUint32(&baseport, 1)) + l := echoServer(port) + defer l.Close() + + cli, err := dialEcho(port) + if err != nil { + panic(err) + } + + // double close + cli.Close() + if cli.Close() == nil { + t.Fatal("double close misbehavior") + } + + // write after close + buf := make([]byte, 10) + n, err = cli.Write(buf) + if n != 0 || err == nil { + t.Fatal("write after close misbehavior") + } + + // write, close, read, read + cli, err = dialEcho(port) + if err != nil { + panic(err) + } + if n, err = cli.Write(buf); err != nil { + t.Fatal("write misbehavior") + } + + // wait until data arrival + time.Sleep(2 * time.Second) + // drain + cli.Close() + n, err = io.ReadFull(cli, buf) + if err != nil { + t.Fatal("closed conn drain bytes failed", err, n) + } + + // after drain, read should return error + n, err = cli.Read(buf) + if n != 0 || err == nil { + t.Fatal("write->close->drain->read misbehavior", err, n) + } + cli.Close() +} + +func TestParallel1024CLIENT_64BMSG_64CNT(t *testing.T) { + port := 
int(atomic.AddUint32(&baseport, 1))
+    l := echoServer(port)
+    defer l.Close()
+
+    var wg sync.WaitGroup
+    wg.Add(1024)
+    for i := 0; i < 1024; i++ {
+        go parallel_client(&wg, port)
+    }
+    wg.Wait()
+}
+
+func parallel_client(wg *sync.WaitGroup, port int) (err error) {
+    cli, err := dialEcho(port)
+    if err != nil {
+        panic(err)
+    }
+
+    err = echo_tester(cli, 64, 64)
+    cli.Close()
+    wg.Done()
+    return
+}
+
+func BenchmarkEchoSpeed4K(b *testing.B) {
+    speedclient(b, 4096)
+}
+
+func BenchmarkEchoSpeed64K(b *testing.B) {
+    speedclient(b, 65536)
+}
+
+func BenchmarkEchoSpeed512K(b *testing.B) {
+    speedclient(b, 524288)
+}
+
+func BenchmarkEchoSpeed1M(b *testing.B) {
+    speedclient(b, 1048576)
+}
+
+func speedclient(b *testing.B, nbytes int) {
+    port := int(atomic.AddUint32(&baseport, 1))
+    l := echoServer(port)
+    defer l.Close()
+
+    b.ReportAllocs()
+    cli, err := dialEcho(port)
+    if err != nil {
+        panic(err)
+    }
+
+    if err := echo_tester(cli, nbytes, b.N); err != nil {
+        b.Fail()
+    }
+    b.SetBytes(int64(nbytes))
+    cli.Close()
+}
+
+func BenchmarkSinkSpeed4K(b *testing.B) {
+    sinkclient(b, 4096)
+}
+
+func BenchmarkSinkSpeed64K(b *testing.B) {
+    sinkclient(b, 65536)
+}
+
+func BenchmarkSinkSpeed512K(b *testing.B) {
+    sinkclient(b, 524288)
+}
+
+func BenchmarkSinkSpeed1M(b *testing.B) {
+    sinkclient(b, 1048576)
+}
+
+func sinkclient(b *testing.B, nbytes int) {
+    port := int(atomic.AddUint32(&baseport, 1))
+    l := sinkServer(port)
+    defer l.Close()
+
+    b.ReportAllocs()
+    cli, err := dialSink(port)
+    if err != nil {
+        panic(err)
+    }
+
+    sink_tester(cli, nbytes, b.N)
+    b.SetBytes(int64(nbytes))
+    cli.Close()
+}
+
+func echo_tester(cli net.Conn, msglen, msgcount int) error {
+    buf := make([]byte, msglen)
+    for i := 0; i < msgcount; i++ {
+        // send packet
+        if _, err := cli.Write(buf); err != nil {
+            return err
+        }
+
+        // receive packet
+        nrecv := 0
+        for {
+            n, err := cli.Read(buf)
+            if err != nil {
+                return err
+            } else {
+                nrecv += n
+                if nrecv == msglen {
+                    break
+                }
+            }
+        }
+    }
+    return nil
+}
+
+func sink_tester(cli *UDPSession, msglen, msgcount int) error {
+    // sender
+    buf := make([]byte, msglen)
+    for i := 0; i < msgcount; i++ {
+        if _, err := cli.Write(buf); err != nil {
+            return err
+        }
+    }
+    return nil
+}
+
+func TestSNMP(t *testing.T) {
+    t.Log(DefaultSnmp.Copy())
+    t.Log(DefaultSnmp.Header())
+    t.Log(DefaultSnmp.ToSlice())
+    DefaultSnmp.Reset()
+    t.Log(DefaultSnmp.ToSlice())
+}
+
+func TestListenerClose(t *testing.T) {
+    port := int(atomic.AddUint32(&baseport, 1))
+    l, err := ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 10, 3)
+    if err != nil {
+        t.Fail()
+    }
+    l.SetReadDeadline(time.Now().Add(time.Second))
+    l.SetWriteDeadline(time.Now().Add(time.Second))
+    l.SetDeadline(time.Now().Add(time.Second))
+    time.Sleep(2 * time.Second)
+    if _, err := l.Accept(); err == nil {
+        t.Fail()
+    }
+
+    l.Close()
+    fakeaddr, _ := net.ResolveUDPAddr("udp6", "127.0.0.1:1111")
+    if l.closeSession(fakeaddr) {
+        t.Fail()
+    }
+}
+
+// A wrapper for net.PacketConn that remembers when Close has been called.
+type closedFlagPacketConn struct {
+    net.PacketConn
+    Closed bool
+}
+
+func (c *closedFlagPacketConn) Close() error {
+    c.Closed = true
+    return c.PacketConn.Close()
+}
+
+func newClosedFlagPacketConn(c net.PacketConn) *closedFlagPacketConn {
+    return &closedFlagPacketConn{c, false}
+}
+
+// Listener should close a net.PacketConn that it created.
+// https://github.com/xtaci/kcp-go/issues/165
+func TestListenerOwnedPacketConn(t *testing.T) {
+    // ListenWithOptions creates its own net.PacketConn.
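+    // ownConn is therefore true, so Close() is expected to close it (see serveConn).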
+ l, err := ListenWithOptions("127.0.0.1:0", nil, 0, 0) + if err != nil { + panic(err) + } + defer l.Close() + // Replace the internal net.PacketConn with one that remembers when it + // has been closed. + pconn := newClosedFlagPacketConn(l.conn) + l.conn = pconn + + if pconn.Closed { + t.Fatal("owned PacketConn closed before Listener.Close()") + } + + err = l.Close() + if err != nil { + panic(err) + } + + if !pconn.Closed { + t.Fatal("owned PacketConn not closed after Listener.Close()") + } +} + +// Listener should not close a net.PacketConn that it did not create. +// https://github.com/xtaci/kcp-go/issues/165 +func TestListenerNonOwnedPacketConn(t *testing.T) { + // Create a net.PacketConn not owned by the Listener. + c, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + panic(err) + } + defer c.Close() + // Make it remember when it has been closed. + pconn := newClosedFlagPacketConn(c) + + l, err := ServeConn(nil, 0, 0, pconn) + if err != nil { + panic(err) + } + defer l.Close() + + if pconn.Closed { + t.Fatal("non-owned PacketConn closed before Listener.Close()") + } + + err = l.Close() + if err != nil { + panic(err) + } + + if pconn.Closed { + t.Fatal("non-owned PacketConn closed after Listener.Close()") + } +} + +// UDPSession should close a net.PacketConn that it created. +// https://github.com/xtaci/kcp-go/issues/165 +func TestUDPSessionOwnedPacketConn(t *testing.T) { + l := sinkServer(0) + defer l.Close() + + // DialWithOptions creates its own net.PacketConn. + client, err := DialWithOptions(l.Addr().String(), nil, 0, 0) + if err != nil { + panic(err) + } + defer client.Close() + // Replace the internal net.PacketConn with one that remembers when it + // has been closed. + pconn := newClosedFlagPacketConn(client.conn) + client.conn = pconn + + if pconn.Closed { + t.Fatal("owned PacketConn closed before UDPSession.Close()") + } + + err = client.Close() + if err != nil { + panic(err) + } + + if !pconn.Closed { + t.Fatal("owned PacketConn not closed after UDPSession.Close()") + } +} + +// UDPSession should not close a net.PacketConn that it did not create. +// https://github.com/xtaci/kcp-go/issues/165 +func TestUDPSessionNonOwnedPacketConn(t *testing.T) { + l := sinkServer(0) + defer l.Close() + + // Create a net.PacketConn not owned by the UDPSession. + c, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + panic(err) + } + defer c.Close() + // Make it remember when it has been closed. 
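+    // The session receives pconn from the caller (ownConn=false in newUDPSession),
+    // so Close() must leave it open.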
+    pconn := newClosedFlagPacketConn(c)
+
+    client, err := NewConn2(l.Addr(), nil, 0, 0, pconn)
+    if err != nil {
+        panic(err)
+    }
+    defer client.Close()
+
+    if pconn.Closed {
+        t.Fatal("non-owned PacketConn closed before UDPSession.Close()")
+    }
+
+    err = client.Close()
+    if err != nil {
+        panic(err)
+    }
+
+    if pconn.Closed {
+        t.Fatal("non-owned PacketConn closed after UDPSession.Close()")
+    }
+}
diff --git a/vendor/github.com/xtaci/kcp-go/v5/snmp.go b/vendor/github.com/xtaci/kcp-go/v5/snmp.go
new file mode 100644
index 00000000..f9618107
--- /dev/null
+++ b/vendor/github.com/xtaci/kcp-go/v5/snmp.go
@@ -0,0 +1,164 @@
+package kcp
+
+import (
+    "fmt"
+    "sync/atomic"
+)
+
+// Snmp defines network statistics indicators
+type Snmp struct {
+    BytesSent        uint64 // bytes sent from upper level
+    BytesReceived    uint64 // bytes received to upper level
+    MaxConn          uint64 // max number of connections ever reached
+    ActiveOpens      uint64 // accumulated active open connections
+    PassiveOpens     uint64 // accumulated passive open connections
+    CurrEstab        uint64 // current number of established connections
+    InErrs           uint64 // UDP read errors reported from net.PacketConn
+    InCsumErrors     uint64 // checksum errors from CRC32
+    KCPInErrors      uint64 // packet input errors reported from KCP
+    InPkts           uint64 // incoming packets count
+    OutPkts          uint64 // outgoing packets count
+    InSegs           uint64 // incoming KCP segments
+    OutSegs          uint64 // outgoing KCP segments
+    InBytes          uint64 // UDP bytes received
+    OutBytes         uint64 // UDP bytes sent
+    RetransSegs      uint64 // accumulated retransmitted segments
+    FastRetransSegs  uint64 // accumulated fast retransmitted segments
+    EarlyRetransSegs uint64 // accumulated early retransmitted segments
+    LostSegs         uint64 // number of segs inferred as lost
+    RepeatSegs       uint64 // number of segs duplicated
+    FECRecovered     uint64 // correct packets recovered from FEC
+    FECErrs          uint64 // incorrect packets recovered from FEC
+    FECParityShards  uint64 // FEC segments received
+    FECShortShards   uint64 // number of data shards that were not enough for recovery
+}
+
+func newSnmp() *Snmp {
+    return new(Snmp)
+}
+
+// Header returns all field names
+func (s *Snmp) Header() []string {
+    return []string{
+        "BytesSent",
+        "BytesReceived",
+        "MaxConn",
+        "ActiveOpens",
+        "PassiveOpens",
+        "CurrEstab",
+        "InErrs",
+        "InCsumErrors",
+        "KCPInErrors",
+        "InPkts",
+        "OutPkts",
+        "InSegs",
+        "OutSegs",
+        "InBytes",
+        "OutBytes",
+        "RetransSegs",
+        "FastRetransSegs",
+        "EarlyRetransSegs",
+        "LostSegs",
+        "RepeatSegs",
+        "FECParityShards",
+        "FECErrs",
+        "FECRecovered",
+        "FECShortShards",
+    }
+}
+
+// ToSlice returns the current snmp info as a slice
+func (s *Snmp) ToSlice() []string {
+    snmp := s.Copy()
+    return []string{
+        fmt.Sprint(snmp.BytesSent),
+        fmt.Sprint(snmp.BytesReceived),
+        fmt.Sprint(snmp.MaxConn),
+        fmt.Sprint(snmp.ActiveOpens),
+        fmt.Sprint(snmp.PassiveOpens),
+        fmt.Sprint(snmp.CurrEstab),
+        fmt.Sprint(snmp.InErrs),
+        fmt.Sprint(snmp.InCsumErrors),
+        fmt.Sprint(snmp.KCPInErrors),
+        fmt.Sprint(snmp.InPkts),
+        fmt.Sprint(snmp.OutPkts),
+        fmt.Sprint(snmp.InSegs),
+        fmt.Sprint(snmp.OutSegs),
+        fmt.Sprint(snmp.InBytes),
+        fmt.Sprint(snmp.OutBytes),
+        fmt.Sprint(snmp.RetransSegs),
+        fmt.Sprint(snmp.FastRetransSegs),
+        fmt.Sprint(snmp.EarlyRetransSegs),
+        fmt.Sprint(snmp.LostSegs),
+        fmt.Sprint(snmp.RepeatSegs),
+        fmt.Sprint(snmp.FECParityShards),
+        fmt.Sprint(snmp.FECErrs),
+        fmt.Sprint(snmp.FECRecovered),
+        fmt.Sprint(snmp.FECShortShards),
+    }
+}
+
+// Copy makes a copy of the current snmp snapshot
+func (s *Snmp) Copy() *Snmp {
+    d :=
newSnmp() + d.BytesSent = atomic.LoadUint64(&s.BytesSent) + d.BytesReceived = atomic.LoadUint64(&s.BytesReceived) + d.MaxConn = atomic.LoadUint64(&s.MaxConn) + d.ActiveOpens = atomic.LoadUint64(&s.ActiveOpens) + d.PassiveOpens = atomic.LoadUint64(&s.PassiveOpens) + d.CurrEstab = atomic.LoadUint64(&s.CurrEstab) + d.InErrs = atomic.LoadUint64(&s.InErrs) + d.InCsumErrors = atomic.LoadUint64(&s.InCsumErrors) + d.KCPInErrors = atomic.LoadUint64(&s.KCPInErrors) + d.InPkts = atomic.LoadUint64(&s.InPkts) + d.OutPkts = atomic.LoadUint64(&s.OutPkts) + d.InSegs = atomic.LoadUint64(&s.InSegs) + d.OutSegs = atomic.LoadUint64(&s.OutSegs) + d.InBytes = atomic.LoadUint64(&s.InBytes) + d.OutBytes = atomic.LoadUint64(&s.OutBytes) + d.RetransSegs = atomic.LoadUint64(&s.RetransSegs) + d.FastRetransSegs = atomic.LoadUint64(&s.FastRetransSegs) + d.EarlyRetransSegs = atomic.LoadUint64(&s.EarlyRetransSegs) + d.LostSegs = atomic.LoadUint64(&s.LostSegs) + d.RepeatSegs = atomic.LoadUint64(&s.RepeatSegs) + d.FECParityShards = atomic.LoadUint64(&s.FECParityShards) + d.FECErrs = atomic.LoadUint64(&s.FECErrs) + d.FECRecovered = atomic.LoadUint64(&s.FECRecovered) + d.FECShortShards = atomic.LoadUint64(&s.FECShortShards) + return d +} + +// Reset values to zero +func (s *Snmp) Reset() { + atomic.StoreUint64(&s.BytesSent, 0) + atomic.StoreUint64(&s.BytesReceived, 0) + atomic.StoreUint64(&s.MaxConn, 0) + atomic.StoreUint64(&s.ActiveOpens, 0) + atomic.StoreUint64(&s.PassiveOpens, 0) + atomic.StoreUint64(&s.CurrEstab, 0) + atomic.StoreUint64(&s.InErrs, 0) + atomic.StoreUint64(&s.InCsumErrors, 0) + atomic.StoreUint64(&s.KCPInErrors, 0) + atomic.StoreUint64(&s.InPkts, 0) + atomic.StoreUint64(&s.OutPkts, 0) + atomic.StoreUint64(&s.InSegs, 0) + atomic.StoreUint64(&s.OutSegs, 0) + atomic.StoreUint64(&s.InBytes, 0) + atomic.StoreUint64(&s.OutBytes, 0) + atomic.StoreUint64(&s.RetransSegs, 0) + atomic.StoreUint64(&s.FastRetransSegs, 0) + atomic.StoreUint64(&s.EarlyRetransSegs, 0) + atomic.StoreUint64(&s.LostSegs, 0) + atomic.StoreUint64(&s.RepeatSegs, 0) + atomic.StoreUint64(&s.FECParityShards, 0) + atomic.StoreUint64(&s.FECErrs, 0) + atomic.StoreUint64(&s.FECRecovered, 0) + atomic.StoreUint64(&s.FECShortShards, 0) +} + +// DefaultSnmp is the global KCP connection statistics collector +var DefaultSnmp *Snmp + +func init() { + DefaultSnmp = newSnmp() +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/timedsched.go b/vendor/github.com/xtaci/kcp-go/v5/timedsched.go new file mode 100644 index 00000000..2db7c206 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/timedsched.go @@ -0,0 +1,146 @@ +package kcp + +import ( + "container/heap" + "runtime" + "sync" + "time" +) + +// SystemTimedSched is the library level timed-scheduler +var SystemTimedSched *TimedSched = NewTimedSched(runtime.NumCPU()) + +type timedFunc struct { + execute func() + ts time.Time +} + +// a heap for sorted timed function +type timedFuncHeap []timedFunc + +func (h timedFuncHeap) Len() int { return len(h) } +func (h timedFuncHeap) Less(i, j int) bool { return h[i].ts.Before(h[j].ts) } +func (h timedFuncHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *timedFuncHeap) Push(x interface{}) { *h = append(*h, x.(timedFunc)) } +func (h *timedFuncHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + old[n-1].execute = nil // avoid memory leak + *h = old[0 : n-1] + return x +} + +// TimedSched represents the control struct for timed parallel scheduler +type TimedSched struct { + // prepending tasks + prependTasks []timedFunc + prependLock 
sync.Mutex
+    chPrependNotify chan struct{}
+
+    // tasks will be distributed through chTask
+    chTask chan timedFunc
+
+    dieOnce sync.Once
+    die     chan struct{}
+}
+
+// NewTimedSched creates a timed-task scheduler with the given degree of parallelism
+func NewTimedSched(parallel int) *TimedSched {
+    ts := new(TimedSched)
+    ts.chTask = make(chan timedFunc)
+    ts.die = make(chan struct{})
+    ts.chPrependNotify = make(chan struct{}, 1)
+
+    for i := 0; i < parallel; i++ {
+        go ts.sched()
+    }
+    go ts.prepend()
+    return ts
+}
+
+func (ts *TimedSched) sched() {
+    var tasks timedFuncHeap
+    timer := time.NewTimer(0)
+    drained := false
+    for {
+        select {
+        case task := <-ts.chTask:
+            now := time.Now()
+            if now.After(task.ts) {
+                // already delayed! execute immediately
+                task.execute()
+            } else {
+                heap.Push(&tasks, task)
+                // properly reset timer to trigger based on the top element
+                stopped := timer.Stop()
+                if !stopped && !drained {
+                    <-timer.C
+                }
+                timer.Reset(tasks[0].ts.Sub(now))
+                drained = false
+            }
+        case now := <-timer.C:
+            drained = true
+            for tasks.Len() > 0 {
+                if now.After(tasks[0].ts) {
+                    heap.Pop(&tasks).(timedFunc).execute()
+                } else {
+                    timer.Reset(tasks[0].ts.Sub(now))
+                    drained = false
+                    break
+                }
+            }
+        case <-ts.die:
+            return
+        }
+    }
+}
+
+func (ts *TimedSched) prepend() {
+    var tasks []timedFunc
+    for {
+        select {
+        case <-ts.chPrependNotify:
+            ts.prependLock.Lock()
+            // keep cap to reuse slice
+            if cap(tasks) < cap(ts.prependTasks) {
+                tasks = make([]timedFunc, 0, cap(ts.prependTasks))
+            }
+            tasks = tasks[:len(ts.prependTasks)]
+            copy(tasks, ts.prependTasks)
+            for k := range ts.prependTasks {
+                ts.prependTasks[k].execute = nil // avoid memory leak
+            }
+            ts.prependTasks = ts.prependTasks[:0]
+            ts.prependLock.Unlock()
+
+            for k := range tasks {
+                select {
+                case ts.chTask <- tasks[k]:
+                    tasks[k].execute = nil // avoid memory leak
+                case <-ts.die:
+                    return
+                }
+            }
+            tasks = tasks[:0]
+        case <-ts.die:
+            return
+        }
+    }
+}
+
+// Put schedules function 'f' to be executed at 'deadline'
+func (ts *TimedSched) Put(f func(), deadline time.Time) {
+    ts.prependLock.Lock()
+    ts.prependTasks = append(ts.prependTasks, timedFunc{f, deadline})
+    ts.prependLock.Unlock()
+
+    select {
+    case ts.chPrependNotify <- struct{}{}:
+    default:
+    }
+}
+
+// Close terminates this scheduler
+func (ts *TimedSched) Close() { ts.dieOnce.Do(func() { close(ts.die) }) }
diff --git a/vendor/github.com/xtaci/kcp-go/v5/tx.go b/vendor/github.com/xtaci/kcp-go/v5/tx.go
new file mode 100644
index 00000000..3397b82e
--- /dev/null
+++ b/vendor/github.com/xtaci/kcp-go/v5/tx.go
@@ -0,0 +1,24 @@
+package kcp
+
+import (
+    "sync/atomic"
+
+    "github.com/pkg/errors"
+    "golang.org/x/net/ipv4"
+)
+
+func (s *UDPSession) defaultTx(txqueue []ipv4.Message) {
+    nbytes := 0
+    npkts := 0
+    for k := range txqueue {
+        if n, err := s.conn.WriteTo(txqueue[k].Buffers[0], txqueue[k].Addr); err == nil {
+            nbytes += n
+            npkts++
+        } else {
+            s.notifyWriteError(errors.WithStack(err))
+            break
+        }
+    }
+    atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
+    atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
+}
diff --git a/vendor/github.com/xtaci/kcp-go/v5/tx_generic.go b/vendor/github.com/xtaci/kcp-go/v5/tx_generic.go
new file mode 100644
index 00000000..0b4f3494
--- /dev/null
+++ b/vendor/github.com/xtaci/kcp-go/v5/tx_generic.go
@@ -0,0 +1,11 @@
+// +build !linux
+
+package kcp
+
+import (
+    "golang.org/x/net/ipv4"
+)
+
+func (s *UDPSession) tx(txqueue []ipv4.Message) {
+    s.defaultTx(txqueue)
+}
diff --git a/vendor/github.com/xtaci/kcp-go/v5/tx_linux.go
b/vendor/github.com/xtaci/kcp-go/v5/tx_linux.go new file mode 100644 index 00000000..4f19df56 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/tx_linux.go @@ -0,0 +1,51 @@ +// +build linux + +package kcp + +import ( + "net" + "os" + "sync/atomic" + + "github.com/pkg/errors" + "golang.org/x/net/ipv4" +) + +func (s *UDPSession) tx(txqueue []ipv4.Message) { + // default version + if s.xconn == nil || s.xconnWriteError != nil { + s.defaultTx(txqueue) + return + } + + // x/net version + nbytes := 0 + npkts := 0 + for len(txqueue) > 0 { + if n, err := s.xconn.WriteBatch(txqueue, 0); err == nil { + for k := range txqueue[:n] { + nbytes += len(txqueue[k].Buffers[0]) + } + npkts += n + txqueue = txqueue[n:] + } else { + // compatibility issue: + // for linux kernel<=2.6.32, support for sendmmsg is not available + // an error of type os.SyscallError will be returned + if operr, ok := err.(*net.OpError); ok { + if se, ok := operr.Err.(*os.SyscallError); ok { + if se.Syscall == "sendmmsg" { + s.xconnWriteError = se + s.defaultTx(txqueue) + return + } + } + } + s.notifyWriteError(errors.WithStack(err)) + break + } + } + + atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts)) + atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes)) +} diff --git a/vendor/github.com/xtaci/kcp-go/v5/wireshark/README.md b/vendor/github.com/xtaci/kcp-go/v5/wireshark/README.md new file mode 100644 index 00000000..de47070c --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/wireshark/README.md @@ -0,0 +1,2 @@ +## macOS + - Copy kcp_dissector.lua to /Applications/Wireshark.app/Contents/PlugIns/wireshark diff --git a/vendor/github.com/xtaci/kcp-go/v5/wireshark/kcp_dissector.lua b/vendor/github.com/xtaci/kcp-go/v5/wireshark/kcp_dissector.lua new file mode 100644 index 00000000..67e4b9b2 --- /dev/null +++ b/vendor/github.com/xtaci/kcp-go/v5/wireshark/kcp_dissector.lua @@ -0,0 +1,82 @@ +-- create protocol +kcp_protocol = Proto("KCP", "KCP Protocol") + +-- fields for kcp +conv = ProtoField.uint32("kcp.conv", "conv", base.DEC) +cmd = ProtoField.uint8("kcp.cmd", "cmd", base.DEC) +frg = ProtoField.uint8("kcp.frg", "frg", base.DEC) +wnd = ProtoField.uint16("kcp.wnd", "wnd", base.DEC) +ts = ProtoField.uint32("kcp.ts", "ts", base.DEC) +sn = ProtoField.uint32("kcp.sn", "sn", base.DEC) +una = ProtoField.uint32("kcp.una", "una", base.DEC) +len = ProtoField.uint32("kcp.len", "len", base.DEC) + +kcp_protocol.fields = {conv, cmd, frg, wnd, ts, sn, una, len} + +-- dissect each udp packet +function kcp_protocol.dissector(buffer, pinfo, tree) + length = buffer:len() + if length == 0 then + return + end + + local offset_s = 8 + local first_sn = buffer(offset_s + 12, 4):le_int() + local first_len = buffer(offset_s + 20, 4):le_int() + local first_cmd_name = get_cmd_name(buffer(offset_s + 4, 1):le_int()) + local info = string.format("[%s] Sn=%d Kcplen=%d", first_cmd_name, first_sn, first_len) + + pinfo.cols.protocol = kcp_protocol.name + udp_info = string.gsub(tostring(pinfo.cols.info), "Len", "Udplen", 1) + pinfo.cols.info = string.gsub(udp_info, " U", info .. 
" U", 1) + + -- dssect multi kcp packet in udp + local offset = 8 + while offset < buffer:len() do + local conv_buf = buffer(offset + 0, 4) + local cmd_buf = buffer(offset + 4, 1) + local wnd_buf = buffer(offset + 6, 2) + local sn_buf = buffer(offset + 12, 4) + local len_buf = buffer(offset + 20, 4) + + local cmd_name = get_cmd_name(cmd_buf:le_int()) + local data_len = len_buf:le_int() + + local tree_title = + string.format( + "KCP Protocol, %s, Sn: %d, Conv: %d, Wnd: %d, Len: %d", + cmd_name, + sn_buf:le_int(), + conv_buf:le_int(), + wnd_buf:le_int(), + data_len + ) + local subtree = tree:add(kcp_protocol, buffer(), tree_title) + subtree:add_le(conv, conv_buf) + subtree:add_le(cmd, cmd_buf):append_text(" (" .. cmd_name .. ")") + subtree:add_le(frg, buffer(offset + 5, 1)) + subtree:add_le(wnd, wnd_buf) + subtree:add_le(ts, buffer(offset + 8, 4)) + subtree:add_le(sn, sn_buf) + subtree:add_le(una, buffer(offset + 16, 4)) + subtree:add_le(len, len_buf) + offset = offset + 24 + data_len + end +end + +function get_cmd_name(cmd_val) + if cmd_val == 81 then + return "PSH" + elseif cmd_val == 82 then + return "ACK" + elseif cmd_val == 83 then + return "ASK" + elseif cmd_val == 84 then + return "TELL" + end +end + +-- register kcp dissector to udp +local udp_port = DissectorTable.get("udp.port") +-- replace 8081 to the port for kcp +udp_port:add(8081, kcp_protocol) diff --git a/walletapi/daemon_communication.go b/walletapi/daemon_communication.go index 06789795..edd0e855 100644 --- a/walletapi/daemon_communication.go +++ b/walletapi/daemon_communication.go @@ -59,6 +59,19 @@ import "github.com/creachadair/jrpc2" // this global variable should be within wallet structure var Connected bool = false +var daemon_height int64 +var daemon_topoheight int64 + +// return daemon height +func Get_Daemon_Height() int64 { + return daemon_height +} + +// return topoheight of daemon +func Get_Daemon_TopoHeight() int64 { + return daemon_topoheight +} + var simulator bool // turns on simulator, which has 0 fees // there should be no global variables, so multiple wallets can run at the same time with different assset @@ -95,6 +108,7 @@ func Notify_broadcaster(req *jrpc2.Request) { NotifyHeightChange.L.Lock() NotifyHeightChange.Broadcast() NotifyHeightChange.L.Unlock() + go test_connectivity() case "MiniBlock": // we can skip this default: logger.V(1).Info("Notification received", "method", req.Method()) @@ -157,6 +171,8 @@ func test_connectivity() (err error) { if strings.ToLower(info.Network) == "simulator" { simulator = true } + daemon_height = info.Height + daemon_topoheight = info.TopoHeight logger.Info("successfully connected to daemon") return nil } @@ -398,6 +414,9 @@ func (w *Wallet_Memory) GetEncryptedBalanceAtTopoHeight(scid crypto.Hash, topohe if topoheight == -1 { w.Daemon_Height = uint64(result.DHeight) w.Daemon_TopoHeight = result.DTopoheight + + daemon_height = result.DHeight + daemon_topoheight = result.DTopoheight w.Merkle_Balance_TreeHash = result.DMerkle_Balance_TreeHash }