diff --git a/core/node/event_handlers.go b/core/node/event_handlers.go index ecb9d1f25f..74a330d71d 100644 --- a/core/node/event_handlers.go +++ b/core/node/event_handlers.go @@ -32,8 +32,8 @@ func (n *Node) handleContactRequest(ctx context.Context, input *entity.Event) er // FIXME: validate input sql := n.sql(ctx) - _, err = bsql.FindContact(sql, &entity.Contact{ID: attrs.Me.ID}) - if err == nil { + contact, err := bsql.FindContact(sql, &entity.Contact{ID: attrs.Me.ID}) + if err == nil && contact.Status != entity.Contact_Unknown { return errorcodes.ErrContactReqExisting.New() } diff --git a/core/node/nodeapi.go b/core/node/nodeapi.go index 61c8ceffec..24505b1f61 100644 --- a/core/node/nodeapi.go +++ b/core/node/nodeapi.go @@ -228,11 +228,11 @@ func (n *Node) ContactRequest(ctx context.Context, req *node.ContactRequestInput sql := n.sql(ctx) contact, err := bsql.FindContact(sql, req.ToContact()) - if errors.Cause(err) == gorm.ErrRecordNotFound { + if errors.Cause(err) == gorm.ErrRecordNotFound || contact.Status == entity.Contact_Unknown { // save contact in database contact = req.ToContact() contact.Status = entity.Contact_IsRequested - if err = sql.Set("gorm:association_autoupdate", true).Save(contact).Error; err != nil { + if err = bsql.ContactSave(sql, contact); err != nil { return nil, errorcodes.ErrDbCreate.Wrap(err) } } else if err != nil { @@ -250,7 +250,7 @@ func (n *Node) ContactRequest(ctx context.Context, req *node.ContactRequestInput } else if contact.Status == entity.Contact_Myself { return nil, errorcodes.ErrContactReqMyself.New() - } else if contact.Status != entity.Contact_Unknown { + } else { return nil, errorcodes.ErrContactReqExisting.New() } diff --git a/core/sql/helpers.go b/core/sql/helpers.go index 686752ee85..eda67d5c68 100644 --- a/core/sql/helpers.go +++ b/core/sql/helpers.go @@ -180,3 +180,20 @@ func ConversationSave(db *gorm.DB, c *entity.Conversation) error { return nil } + +func ContactSave(db *gorm.DB, c *entity.Contact) error { + if err := db.Save(c).Error; err != nil { + logger().Error(fmt.Sprintf("cannot save contact %+v, err: %+v", c, err.Error())) + return err + } + + c.Devices = append(c.Devices, &entity.Device{ID: c.ID, ContactID: c.ID}) + for _, device := range c.Devices { + if err := db.Save(device).Error; err != nil { + logger().Error(fmt.Sprintf("cannot save devices %+v, err %+v", device, err.Error())) + return err + } + } + + return nil +}