-
Notifications
You must be signed in to change notification settings - Fork 883
/
pool_generation_counter.go
152 lines (126 loc) · 4.28 KB
/
pool_generation_counter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"sync"
"sync/atomic"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Pool generation state values stored (atomically) in poolGenerationMap.state. The disconnected value is 0, so a
// freshly allocated poolGenerationMap starts out disconnected until connect is called.
const (
	generationDisconnected int64 = 0
	generationConnected    int64 = 1
)
// generationStats is the bookkeeping kept for one service ID in a poolGenerationMap.
//
// generation is the current pool version for the service ID; it only ever moves forward (via clear).
// numConns is the number of connections registered with addConnection and not yet released with removeConnection.
// It is a running count across generations; advancing the generation does not reset it.
type generationStats struct {
	generation, numConns uint64
}
// poolGenerationMap tracks the version for each service ID present in a pool. For deployments that are not behind a
// load balancer, there is only one service ID: primitive.NilObjectID. For load-balanced deployments, each server behind
// the load balancer will have a unique service ID.
//
// Concurrency: state is read and written only through sync/atomic. generationMap is read and written only while
// holding the embedded mutex.
type poolGenerationMap struct {
	// state must be accessed using the atomic package and should be at the beginning of the struct.
	// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
	// - suggested layout: https://go101.org/article/memory-layout.html
	// It holds generationConnected or generationDisconnected (set by connect / disconnect). The zero value is
	// generationDisconnected.
	state int64
	// generationMap maps a service ID to its generation bookkeeping; entries are created lazily by addConnection
	// and deleted by removeConnection once their connection count drops to zero.
	generationMap map[primitive.ObjectID]*generationStats
	// Embedded mutex guarding generationMap and the *generationStats values it points to.
	sync.Mutex
}
// newPoolGenerationMap returns an empty, disconnected poolGenerationMap. The map is pre-seeded with a zeroed entry
// for primitive.NilObjectID, the single service ID used by deployments that are not load-balanced.
func newPoolGenerationMap() *poolGenerationMap {
	seeded := map[primitive.ObjectID]*generationStats{
		primitive.NilObjectID: {},
	}
	return &poolGenerationMap{generationMap: seeded}
}
// connect atomically flips the map into the connected state. After this, stale judges connections by generation
// instead of reporting every connection as stale.
func (p *poolGenerationMap) connect() {
	statePtr := &p.state
	atomic.StoreInt64(statePtr, generationConnected)
}
// disconnect atomically flips the map into the disconnected state. While disconnected, stale reports true for every
// connection regardless of its generation, so callers close them.
func (p *poolGenerationMap) disconnect() {
	statePtr := &p.state
	atomic.StoreInt64(statePtr, generationDisconnected)
}
// addConnection registers one more connection for the given service ID (nil means primitive.NilObjectID). It returns
// the generation the new connection belongs to. An untracked service ID gets a fresh entry at generation 0.
func (p *poolGenerationMap) addConnection(serviceIDPtr *primitive.ObjectID) uint64 {
	key := getServiceID(serviceIDPtr)

	p.Lock()
	defer p.Unlock()

	entry, tracked := p.generationMap[key]
	if !tracked {
		// First connection seen for this service ID: start tracking it at generation 0.
		entry = &generationStats{}
		p.generationMap[key] = entry
	}
	entry.numConns++
	return entry.generation
}
// removeConnection releases one connection previously registered with addConnection for the given service ID
// (nil means primitive.NilObjectID). When the service ID's count reaches zero, its entry is deleted.
//
// Removing from an untracked service ID, or from an entry whose count is already zero, is a no-op. The zero-count case
// is real: newPoolGenerationMap pre-seeds the NilObjectID entry with numConns == 0. Decrementing that uint64 would wrap
// to 2^64-1, permanently corrupt the count, and keep the entry from ever being deleted.
func (p *poolGenerationMap) removeConnection(serviceIDPtr *primitive.ObjectID) {
	serviceID := getServiceID(serviceIDPtr)
	p.Lock()
	defer p.Unlock()

	stats, ok := p.generationMap[serviceID]
	if !ok || stats.numConns == 0 {
		// Nothing to release; guard against unsigned underflow on an unmatched removal.
		return
	}

	// If the serviceID is being tracked, decrement the connection count and delete this serviceID to prevent the map
	// from growing unboundedly. This case would happen if a server behind a load-balancer was permanently removed
	// and its connections were pruned after a network error or idle timeout.
	stats.numConns--
	if stats.numConns == 0 {
		delete(p.generationMap, serviceID)
	}
}
// clear advances the generation of the given service ID (nil means primitive.NilObjectID) by one. Connections created
// under an earlier generation then report as stale. Clearing an untracked service ID does nothing.
func (p *poolGenerationMap) clear(serviceIDPtr *primitive.ObjectID) {
	key := getServiceID(serviceIDPtr)

	p.Lock()
	defer p.Unlock()

	entry, tracked := p.generationMap[key]
	if !tracked {
		return
	}
	entry.generation++
}
// stale reports whether a connection created at knownGeneration for the given service ID (nil means
// primitive.NilObjectID) should be discarded.
//
// Every connection is stale while the map is disconnected, so that all of them get closed. Otherwise a connection is
// stale once the service ID's generation has moved past knownGeneration. getGeneration reports 0 for an untracked
// service ID, and no uint64 is below 0, so untracked service IDs are never stale.
func (p *poolGenerationMap) stale(serviceIDPtr *primitive.ObjectID, knownGeneration uint64) bool {
	disconnected := atomic.LoadInt64(&p.state) == generationDisconnected
	if disconnected {
		return true
	}
	current := p.getGeneration(serviceIDPtr)
	return current > knownGeneration
}
// getGeneration returns the current generation of the given service ID (nil means primitive.NilObjectID), or 0 when
// the service ID is not tracked.
func (p *poolGenerationMap) getGeneration(serviceIDPtr *primitive.ObjectID) uint64 {
	key := getServiceID(serviceIDPtr)

	p.Lock()
	defer p.Unlock()

	entry, tracked := p.generationMap[key]
	if !tracked {
		return 0
	}
	return entry.generation
}
// getNumConns returns how many connections are currently registered for the given service ID (nil means
// primitive.NilObjectID), or 0 when the service ID is not tracked.
func (p *poolGenerationMap) getNumConns(serviceIDPtr *primitive.ObjectID) uint64 {
	key := getServiceID(serviceIDPtr)

	p.Lock()
	defer p.Unlock()

	entry, tracked := p.generationMap[key]
	if !tracked {
		return 0
	}
	return entry.numConns
}
// getServiceID dereferences an optional service ID. A nil pointer maps to primitive.NilObjectID, the key used for
// deployments that are not load-balanced.
func getServiceID(oid *primitive.ObjectID) primitive.ObjectID {
	id := primitive.NilObjectID
	if oid != nil {
		id = *oid
	}
	return id
}