Skip to content

Commit

Permalink
feat: enable local commands without lms (#486)
Browse files Browse the repository at this point in the history
* feat: enable local config commands without LMS

* fix: move closures to after execution

* refactor: remove status channel as it is not used

---------

Co-authored-by: Matt C. Primrose <matt.c.primrose@intel.com>
  • Loading branch information
rsdmike and matt-primrose committed Jun 20, 2024
1 parent b471edf commit 8b7e7b1
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 65 deletions.
3 changes: 0 additions & 3 deletions internal/flags/maintenance.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,6 @@ func (f *Flags) handleMaintenanceSyncIP() error {
continue
}
addrs, _ := f.netEnumerator.InterfaceAddrs(&i)
if err != nil {
continue
}
for _, address := range addrs {
if ipnet, ok := address.(*net.IPNet); ok &&
ipnet.IP.To4() != nil &&
Expand Down
19 changes: 5 additions & 14 deletions internal/lm/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"encoding/binary"
"rpc/pkg/pthi"
"sync"
"time"

"github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/apf"
Expand All @@ -22,7 +23,7 @@ type LMEConnection struct {
retries int
}

func NewLMEConnection(data chan []byte, errors chan error, status chan bool) *LMEConnection {
func NewLMEConnection(data chan []byte, errors chan error, wg *sync.WaitGroup) *LMEConnection {
lme := &LMEConnection{
ourChannel: 1,
}
Expand All @@ -31,7 +32,7 @@ func NewLMEConnection(data chan []byte, errors chan error, status chan bool) *LM
DataBuffer: data,
ErrorBuffer: errors,
Tempdata: []byte{},
Status: status,
WaitGroup: wg,
}

return lme
Expand Down Expand Up @@ -65,7 +66,7 @@ func (lme *LMEConnection) Connect() error {
} else {
lme.ourChannel = channel
}

lme.Session.WaitGroup.Add(1)
bin_buf := apf.ChannelOpen(lme.ourChannel)
err := lme.Command.Send(bin_buf.Bytes(), uint32(bin_buf.Len()))
if err != nil {
Expand Down Expand Up @@ -133,21 +134,11 @@ func (lme *LMEConnection) Listen() {
lme.Session.DataBuffer <- lme.Session.Tempdata
lme.Session.Tempdata = []byte{}
var bin_buf bytes.Buffer
// var windowAdjust apf.APF_CHANNEL_WINDOW_ADJUST_MESSAGE
// if lme.Session.RXWindow > 1024 { // TODO: Check this
// windowAdjust = apf.ChannelWindowAdjust(lme.Session.RecipientChannel, lme.Session.RXWindow)
// lme.Session.RXWindow = 0
// binary.Write(&bin_buf, binary.BigEndian, windowAdjust.MessageType)
// binary.Write(&bin_buf, binary.BigEndian, windowAdjust.RecipientChannel)
// lme.Command.Call(bin_buf.Bytes(), uint32(bin_buf.Len()))
// }

channelData := apf.ChannelClose(lme.Session.SenderChannel)
binary.Write(&bin_buf, binary.BigEndian, channelData.MessageType)
binary.Write(&bin_buf, binary.BigEndian, channelData.RecipientChannel)

lme.Command.Send(bin_buf.Bytes(), uint32(bin_buf.Len()))
lme.Session.Status <- true
}()
for {
result2, bytesRead, err2 := lme.Command.Receive()
Expand All @@ -167,7 +158,7 @@ func (lme *LMEConnection) Listen() {
}
}

// Close closes the LMS socket connection
// Close closes the LME connection
func (lme *LMEConnection) Close() error {
log.Debug("closing connection to lme")
lme.Command.Close()
Expand Down
32 changes: 20 additions & 12 deletions internal/lm/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package lm
import (
"errors"
"rpc/pkg/pthi"
"sync"
"testing"

"github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/apf"
Expand Down Expand Up @@ -52,11 +53,10 @@ func Test_NewLMEConnection(t *testing.T) {
resetMock()
lmDataChannel := make(chan []byte)
lmErrorChannel := make(chan error)
lmStatusChannel := make(chan bool)
lme := NewLMEConnection(lmDataChannel, lmErrorChannel, lmStatusChannel)
wg := &sync.WaitGroup{}
lme := NewLMEConnection(lmDataChannel, lmErrorChannel, wg)
assert.Equal(t, lmDataChannel, lme.Session.DataBuffer)
assert.Equal(t, lmErrorChannel, lme.Session.ErrorBuffer)
assert.Equal(t, lmStatusChannel, lme.Session.Status)
}
func TestLMEConnection_Initialize(t *testing.T) {
resetMock()
Expand Down Expand Up @@ -96,8 +96,10 @@ func TestLMEConnection_Initialize(t *testing.T) {
sendError = tt.sendErr
initError = tt.initErr
lme := &LMEConnection{
Command: pthiVar,
Session: &apf.Session{},
Command: pthiVar,
Session: &apf.Session{
WaitGroup: &sync.WaitGroup{},
},
ourChannel: 1,
}
if err := lme.Initialize(); (err != nil) != tt.wantErr {
Expand All @@ -112,9 +114,10 @@ func Test_Send(t *testing.T) {
sendBytesWritten = 14

lme := &LMEConnection{
Command: pthiVar,
Session: &apf.Session{},
ourChannel: 1,
Command: pthiVar,
Session: &apf.Session{
WaitGroup: &sync.WaitGroup{},
}, ourChannel: 1,
}
data := []byte("hello")
err := lme.Send(data)
Expand All @@ -124,8 +127,10 @@ func Test_Connect(t *testing.T) {
resetMock()
sendBytesWritten = 54
lme := &LMEConnection{
Command: pthiVar,
Session: &apf.Session{},
Command: pthiVar,
Session: &apf.Session{
WaitGroup: &sync.WaitGroup{},
},
ourChannel: 1,
}
err := lme.Connect()
Expand All @@ -136,8 +141,10 @@ func Test_Connect_With_Error(t *testing.T) {
sendError = errors.New("no such device")
sendBytesWritten = 54
lme := &LMEConnection{
Command: pthiVar,
Session: &apf.Session{},
Command: pthiVar,
Session: &apf.Session{
WaitGroup: &sync.WaitGroup{},
},
ourChannel: 1,
}
err := lme.Connect()
Expand All @@ -154,6 +161,7 @@ func Test_Listen(t *testing.T) {
DataBuffer: lmDataChannel,
ErrorBuffer: lmErrorChannel,
Status: make(chan bool),
WaitGroup: &sync.WaitGroup{},
},
ourChannel: 1,
}
Expand Down
63 changes: 38 additions & 25 deletions internal/local/amt/localTransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,28 @@ import (
"io"
"net/http"
"rpc/internal/lm"
"sync"

"github.com/sirupsen/logrus"
)

// LocalTransport - Your custom net.Conn implementation
type LocalTransport struct {
local lm.LocalMananger
data chan []byte
errors chan error
status chan bool
local lm.LocalMananger
data chan []byte
errors chan error
status chan bool
waitGroup *sync.WaitGroup
}

func NewLocalTransport() *LocalTransport {
lmDataChannel := make(chan []byte)
lmErrorChannel := make(chan error)
lmStatus := make(chan bool)
waiter := &sync.WaitGroup{}
lm := &LocalTransport{
local: lm.NewLMEConnection(lmDataChannel, lmErrorChannel, lmStatus),
data: lmDataChannel,
errors: lmErrorChannel,
status: lmStatus,
local: lm.NewLMEConnection(lmDataChannel, lmErrorChannel, waiter),
data: lmDataChannel,
errors: lmErrorChannel,
waitGroup: waiter,
}
// defer lm.local.Close()
// defer close(lmDataChannel)
Expand All @@ -49,20 +50,18 @@ func NewLocalTransport() *LocalTransport {

// Custom dialer function
func (l *LocalTransport) RoundTrip(r *http.Request) (*http.Response, error) {
//Something comes here...Maybe
go l.local.Listen()

// send channel open
err := l.local.Connect()
//Something comes here...Maybe
go l.local.Listen()

if err != nil {
logrus.Error(err)
return nil, err
}
// wait for channel open confirmation
<-l.status
l.waitGroup.Wait()
logrus.Trace("Channel open confirmation received")

// Serialize the HTTP request to raw form
rawRequest, err := serializeHTTPRequest(r)
if err != nil {
Expand All @@ -71,22 +70,30 @@ func (l *LocalTransport) RoundTrip(r *http.Request) (*http.Response, error) {
}

var responseReader *bufio.Reader
// send our data to LMX
err = l.local.Send(rawRequest)

err = l.local.Send([]byte(rawRequest))
if err != nil {
logrus.Error(err)
return nil, err
}

for dataFromLM := range l.data {
if len(dataFromLM) > 0 {
logrus.Debug("received data from LME")
logrus.Trace(string(dataFromLM))

// /<-l.status
responseReader = bufio.NewReader(bytes.NewReader(dataFromLM))
break
Loop:
for {
select {
case dataFromLM := <-l.data:
if len(dataFromLM) > 0 {
logrus.Debug("received data from LME")
logrus.Trace(string(dataFromLM))
responseReader = bufio.NewReader(bytes.NewReader(dataFromLM))
break Loop
}
case errFromLMS := <-l.errors:
if errFromLMS != nil {
logrus.Error("error from LMS")
break Loop
}
}

}

response, err := http.ReadResponse(responseReader, r)
Expand All @@ -101,6 +108,8 @@ func (l *LocalTransport) RoundTrip(r *http.Request) (*http.Response, error) {
func serializeHTTPRequest(r *http.Request) ([]byte, error) {
var reqBuffer bytes.Buffer

r.Header.Set("Transfer-Encoding", "chunked")

// Write request line
reqLine := fmt.Sprintf("%s %s %s\r\n", r.Method, r.URL.RequestURI(), r.Proto)
reqBuffer.WriteString(reqLine)
Expand All @@ -115,8 +124,12 @@ func serializeHTTPRequest(r *http.Request) ([]byte, error) {
if err != nil {
return nil, err
}
length := fmt.Sprintf("%x", len(bodyBytes))
bodyBytes = append([]byte(length+"\r\n"), bodyBytes...)
bodyBytes = append(bodyBytes, []byte("\r\n0\r\n\r\n")...)
// Important: Replace the body so it can be read again later if needed
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

reqBuffer.Write(bodyBytes)
}

Expand Down
16 changes: 15 additions & 1 deletion internal/local/amt/wsman.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package amt

import (
"encoding/base64"
"net"
"rpc/pkg/utils"

"github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/wsman"
Expand All @@ -29,6 +30,7 @@ import (
"github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/wsman/ips/hostbasedsetup"
"github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/wsman/ips/ieee8021x"
"github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/wsman/ips/optin"
"github.com/sirupsen/logrus"
)

type WSMANer interface {
Expand Down Expand Up @@ -94,7 +96,6 @@ func NewGoWSMANMessages(lmsAddress string) *GoWSMANMessages {
}

func (g *GoWSMANMessages) SetupWsmanClient(username string, password string, logAMTMessages bool) {

clientParams := client.Parameters{
Target: g.target,
Username: username,
Expand All @@ -103,6 +104,19 @@ func (g *GoWSMANMessages) SetupWsmanClient(username string, password string, log
UseTLS: false,
LogAMTMessages: logAMTMessages,
}
logrus.Info("Attempting to connect to LMS...")
port := utils.LMSPort
if clientParams.UseTLS {
port = client.TLSPort
}
con, err := net.Dial("tcp4", utils.LMSAddress+":"+port)
if err != nil {
logrus.Info("Failed to connect to LMS, using local transport instead.")
clientParams.Transport = NewLocalTransport()
} else {
logrus.Info("Successfully connected to LMS.")
con.Close()
}
g.wsmanMessages = wsman.NewMessages(clientParams)
}

Expand Down
Loading

0 comments on commit 8b7e7b1

Please sign in to comment.