diff --git a/cmd/shadowsocks-local/local.go b/cmd/shadowsocks-local/local.go index 0a588bd..dc8fcde 100644 --- a/cmd/shadowsocks-local/local.go +++ b/cmd/shadowsocks-local/local.go @@ -10,6 +10,7 @@ import ( "log" "net" "os" + "path" "strconv" ) @@ -143,7 +144,7 @@ func handleConnection(conn net.Conn, server string, encTbl *ss.EncryptTable) { var err error = nil if err = handShake(conn); err != nil { - log.Println("socks handshack:", err) + log.Println("socks handshake:", err) return } rawaddr, addr, err := getRequest(conn) @@ -195,6 +196,20 @@ func enoughOptions(config *ss.Config) bool { config.LocalPort != 0 && config.Password != "" } +func isFileExists(path string) (bool, error) { + stat, err := os.Stat(path) + if err == nil { + if stat.Mode()&os.ModeType == 0 { + return true, nil + } + return false, errors.New(path + " exists but is not regular file") + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + func main() { var configFile string var cmdConfig ss.Config @@ -208,11 +223,21 @@ func main() { flag.Parse() + exists, err := isFileExists(configFile) + // If no config file in current directory, try search it in the binary directory + // Note there's no portable way to detect the binary directory. + binDir := path.Dir(os.Args[0]) + if (!exists || err != nil) && binDir != "" && binDir != "." { + oldConfig := configFile + configFile = path.Join(binDir, "config.json") + log.Printf("%s not found, try config file %s\n", oldConfig, configFile) + } + config, err := ss.ParseConfig(configFile) if err != nil { enough := enoughOptions(&cmdConfig) if !(enough && os.IsNotExist(err)) { - log.Printf("error reading %s: %v\n", configFile, err) + log.Printf("error reading config file: %v\n", err) } if !enough { return diff --git a/cmd/shadowsocks-server/server.go b/cmd/shadowsocks-server/server.go index a41972e..e6cac64 100644 --- a/cmd/shadowsocks-server/server.go +++ b/cmd/shadowsocks-server/server.go @@ -10,8 +10,11 @@ import ( "log" "net" "os" + "os/signal" "strconv" + "sync" "sync/atomic" + "syscall" "time" ) @@ -95,7 +98,13 @@ func handleConnection(conn *ss.Conn) { debug.Println("connecting", host) remote, err := net.Dial("tcp", host) if err != nil { - debug.Println("error connecting to:", host, err) + if ne, ok := err.(*net.OpError); ok && (ne.Err == syscall.EMFILE || ne.Err == syscall.ENFILE) { + // log too many open file error + // EMFILE is process reaches open file limits, ENFILE is system limit + log.Println("dial error:", err) + } else { + debug.Println("error connecting to:", host, err) + } return } defer remote.Close() @@ -123,29 +132,130 @@ var tableCache = map[string]*ss.EncryptTable{} var tableGetCnt int32 func getTable(password string) (tbl *ss.EncryptTable) { - tbl, ok := tableCache[password] - if ok { - debug.Println("table cache hit for password:", password) - return + if tableCache != nil { + var ok bool + tbl, ok = tableCache[password] + if ok { + debug.Println("table cache hit for password:", password) + return + } + tbl = ss.GetTable(password) + tableCache[password] = tbl + } else { + tbl = ss.GetTable(password) } - tbl = ss.GetTable(password) - tableCache[password] = tbl return } +type PortListener struct { + password string + listener net.Listener +} + +type PasswdManager struct { + sync.Mutex + portListener map[string]*PortListener +} + +func (pm *PasswdManager) add(port, password string, listener net.Listener) { + pm.Lock() + pm.portListener[port] = &PortListener{password, listener} + pm.Unlock() +} + +func (pm *PasswdManager) get(port string) (pl *PortListener, ok bool) { + pm.Lock() + pl, ok = pm.portListener[port] + pm.Unlock() + return +} + +func (pm *PasswdManager) del(port string) { + pl, ok := pm.get(port) + if !ok { + return + } + pl.listener.Close() + pm.Lock() + delete(pm.portListener, port) + pm.Unlock() +} + +func (pm *PasswdManager) updatePortPasswd(port, password string) { + pl, ok := pm.get(port) + if !ok { + log.Printf("new port %s added\n", port) + } else { + if pl.password == password { + return + } + log.Printf("closing port %s to update password\n", port) + pl.listener.Close() + } + // run will add the new port listener to passwdManager. + // So there maybe concurrent access to passwdManager and we need lock to protect it. + go run(port, password) +} + +var passwdManager = PasswdManager{portListener: map[string]*PortListener{}} + +func updatePasswd() { + log.Println("updating password") + newconfig, err := ss.ParseConfig(configFile) + if err != nil { + log.Printf("error parsing config file %s to update password: %v\n", configFile, err) + return + } + oldconfig := config + config = newconfig + + if err = unifyPortPassword(config); err != nil { + return + } + for port, passwd := range config.PortPassword { + passwdManager.updatePortPasswd(port, passwd) + if oldconfig.PortPassword != nil { + delete(oldconfig.PortPassword, port) + } + } + // port password still left in the old config should be closed + for port, _ := range oldconfig.PortPassword { + log.Printf("closing port %s as it's deleted\n", port) + passwdManager.del(port) + } + log.Println("password updated") +} + +func waitSignal() { + var sigChan = make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGHUP) + for sig := range sigChan { + if sig == syscall.SIGHUP { + updatePasswd() + } else { + // is this going to happen? + log.Printf("caught signal %v, exit", sig) + os.Exit(0) + } + } +} + func run(port, password string) { ln, err := net.Listen("tcp", ":"+port) if err != nil { - log.Fatal(err) + log.Printf("try listening port %v: %v\n", port, err) + return } + passwdManager.add(port, password, ln) encTbl := getTable(password) atomic.AddInt32(&tableGetCnt, 1) - log.Printf("starting server at port %v ...\n", port) + log.Printf("server listening port %v ...\n", port) for { conn, err := ln.Accept() if err != nil { - log.Println("accept:", err) - continue + // listener maybe closed to update password + debug.Printf("accept error: %v\n", err) + return } go handleConnection(ss.NewConn(conn, encTbl)) } @@ -155,8 +265,26 @@ func enoughOptions(config *ss.Config) bool { return config.ServerPort != 0 && config.Password != "" } +func unifyPortPassword(config *ss.Config) (err error) { + if len(config.PortPassword) == 0 { // this handles both nil PortPassword and empty one + if !enoughOptions(config) { + log.Println("must specify both port and password") + return errors.New("not enough options") + } + port := strconv.Itoa(config.ServerPort) + config.PortPassword = map[string]string{port: config.Password} + } else { + if config.Password != "" || config.ServerPort != 0 { + log.Println("given port_password, ignore server_port and password option") + } + } + return +} + +var configFile string +var config *ss.Config + func main() { - var configFile string var cmdConfig ss.Config flag.StringVar(&configFile, "c", "config.json", "specify config file") @@ -167,7 +295,8 @@ func main() { flag.Parse() - config, err := ss.ParseConfig(configFile) + var err error + config, err = ss.ParseConfig(configFile) if err != nil { enough := enoughOptions(&cmdConfig) if !(enough && os.IsNotExist(err)) { @@ -183,22 +312,17 @@ func main() { } ss.SetDebug(debug) - if len(config.PortPassword) == 0 { - run(strconv.Itoa(config.ServerPort), config.Password) - } else { - if config.ServerPort != 0 { - log.Println("ignoring server_port and password option, only uses port_password") - } - for port, password := range config.PortPassword { - go run(port, password) - } - // Wait all ports have get it's encryption table - for int(tableGetCnt) != len(config.PortPassword) { - time.Sleep(1 * time.Second) - } - log.Println("all ports ready") - tableCache = nil // release memory - c := make(chan byte) - <-c // block forever + if err = unifyPortPassword(config); err != nil { + os.Exit(1) + } + for port, password := range config.PortPassword { + go run(port, password) + } + // Wait all ports have get it's encryption table + for int(tableGetCnt) != len(config.PortPassword) { + time.Sleep(1 * time.Second) } + log.Println("all ports ready") + tableCache = nil // release memory + waitSignal() }