diff --git a/runtime/list.go b/runtime/list.go index 07819d43..b89ac369 100644 --- a/runtime/list.go +++ b/runtime/list.go @@ -174,31 +174,21 @@ func listRemove(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { if raised := checkMethodArgs(f, "remove", args, ListType, ObjectType); raised != nil { return nil, raised } + value := args[1] l := toListUnsafe(args[0]) l.mutex.Lock() - value := args[1] - found := false - var raised *BaseException - for i, elem := range l.elems { - var eq *Object - if eq, raised = Eq(f, elem, value); raised != nil { - break - } - if found, raised = IsTrue(f, eq); raised != nil { - break - } - if found { - l.elems = append(l.elems[:i], l.elems[i+1:]...) - break + index, raised := seqFindElem(f, l.elems, value) + if raised == nil { + if index != -1 { + l.elems = append(l.elems[:index], l.elems[index+1:]...) + } else { + raised = f.RaiseType(ValueErrorType, "list.remove(x): x not in list") } } l.mutex.Unlock() if raised != nil { return nil, raised } - if !found { - return nil, f.RaiseType(ValueErrorType, "list.remove(x): x not in list") - } return None, nil } @@ -338,6 +328,50 @@ func listNE(f *Frame, v, w *Object) (*Object, *BaseException) { return listCompare(f, toListUnsafe(v), w, NE) } +func listIndex(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { + expectedTypes := []*Type{ListType, ObjectType, ObjectType, ObjectType} + argc := len(args) + var raised *BaseException + if argc == 2 || argc == 3 { + expectedTypes = expectedTypes[:argc] + } + if raised = checkMethodArgs(f, "index", args, expectedTypes...); raised != nil { + return nil, raised + } + l := toListUnsafe(args[0]) + l.mutex.RLock() + numElems := len(l.elems) + start, stop := 0, numElems + if argc > 2 { + start, raised = IndexInt(f, args[2]) + if raised != nil { + l.mutex.RUnlock() + return nil, raised + } + } + if argc > 3 { + stop, raised = IndexInt(f, args[3]) + if raised != nil { + l.mutex.RUnlock() + return nil, raised + } + } + start, stop = adjustIndex(start, stop, numElems) + value := args[1] + index := -1 + if start < numElems && start < stop { + index, raised = seqFindElem(f, l.elems[start:stop], value) + } + l.mutex.RUnlock() + if raised != nil { + return nil, raised + } + if index == -1 { + return nil, f.RaiseType(ValueErrorType, fmt.Sprintf("%v is not in list", value)) + } + return NewInt(index + start).ToObject(), nil +} + func listPop(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { argc := len(args) expectedTypes := []*Type{ListType, ObjectType} @@ -432,6 +466,7 @@ func initListType(dict map[string]*Object) { dict["append"] = newBuiltinFunction("append", listAppend).ToObject() dict["count"] = newBuiltinFunction("count", listCount).ToObject() dict["extend"] = newBuiltinFunction("extend", listExtend).ToObject() + dict["index"] = newBuiltinFunction("index", listIndex).ToObject() dict["insert"] = newBuiltinFunction("insert", listInsert).ToObject() dict["pop"] = newBuiltinFunction("pop", listPop).ToObject() dict["remove"] = newBuiltinFunction("remove", listRemove).ToObject() diff --git a/runtime/list_test.go b/runtime/list_test.go index 55414bf7..a53e5a79 100644 --- a/runtime/list_test.go +++ b/runtime/list_test.go @@ -91,6 +91,37 @@ func TestListCount(t *testing.T) { } } } +func TestListIndex(t *testing.T) { + intIndexType := newTestClass("IntIndex", []*Type{ObjectType}, newStringDict(map[string]*Object{ + "__index__": newBuiltinFunction("__index__", func(f *Frame, _ Args, _ KWArgs) (*Object, *BaseException) { + return NewInt(0).ToObject(), nil + }).ToObject(), + })) + cases := []invokeTestCase{ + // {args: wrapArgs(newTestList(), 1, "foo"), wantExc: mustCreateException(TypeErrorType, "slice indices must be integers or None or have an __index__ method")}, + {args: wrapArgs(newTestList(10, 20, 30), 20), want: NewInt(1).ToObject()}, + {args: wrapArgs(newTestList(10, 20, 30), 20, newObject(intIndexType)), want: NewInt(1).ToObject()}, + {args: wrapArgs(newTestList(0, "foo", "bar"), "foo"), want: NewInt(1).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 3), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 2.0, 2, 3, 4, 2, 1, "foo"), 3, 3), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 4), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 0, 4), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 0, 3), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, -2), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, -1), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 0, -1), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 0, -2), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 0, 999), want: NewInt(3).ToObject()}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), "foo", 0, 999), wantExc: mustCreateException(ValueErrorType, "'foo' is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 999), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + {args: wrapArgs(newTestList(0, 1, 2, 3, 4), 3, 5, 0), wantExc: mustCreateException(ValueErrorType, "3 is not in list")}, + } + for _, cas := range cases { + if err := runInvokeMethodTestCase(ListType, "index", &cas); err != "" { + t.Error(err) + } + } +} func TestListRemove(t *testing.T) { fun := newBuiltinFunction("TestListRemove", func(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { diff --git a/runtime/seq.go b/runtime/seq.go index 8ac8f08b..1509be28 100644 --- a/runtime/seq.go +++ b/runtime/seq.go @@ -173,6 +173,23 @@ func seqFindFirst(f *Frame, iterable *Object, pred func(*Object) (bool, *BaseExc return false, nil } +func seqFindElem(f *Frame, elems []*Object, o *Object) (int, *BaseException) { + for i, elem := range elems { + eq, raised := Eq(f, elem, o) + if raised != nil { + return -1, raised + } + found, raised := IsTrue(f, eq) + if raised != nil { + return -1, raised + } + if found { + return i, nil + } + } + return -1, nil +} + func seqForEach(f *Frame, iterable *Object, callback func(*Object) *BaseException) *BaseException { iter, raised := Iter(f, iterable) if raised != nil {