In [41]:
import torch


def label_list_to_topology(labels: torch.Tensor):
    """
    Converts a list of per-position labels to a topology representation.
    This maps every sequence to list of where each new symbol start (the topology), e.g. AAABBBBCCC -> [(0,A),(3, B)(7,C)]

    Parameters
    ----------
    labels : list or torch.Tensor of ints
        List of labels.

    Returns
    -------
    list of torch.Tensor
        List of tensors that represents the topology.
    """

    if isinstance(labels, list):
        labels = torch.ByteTensor(labels)

    if isinstance(labels, torch.Tensor):
        zero_tensor = torch.ByteTensor([0])
        if labels.is_cuda:
            zero_tensor = zero_tensor.cuda()

        unique, count = torch.unique_consecutive(labels, return_counts=True)
        top_list = [torch.cat((zero_tensor, labels[0:1]))]
        prev_count = 0
        i = 0
        for _ in unique.split(1):
            if i == 0:
                i += 1
                continue
            prev_count += count[i - 1]
            top_list.append(torch.cat((prev_count.view(1), unique[i].view(1))))
            i += 1
        return top_list


# def is_topologies_equal(topology_a, topology_b, minimum_seqment_overlap=5):
#     """
#     Checks whether two topologies are equal.
#     E.g. [(0,A),(3, B)(7,C)]  is the same as [(0,A),(4, B)(7,C)]
#     But not the same as [(0,A),(3, C)(7,B)]

#     Parameters
#     ----------
#     topology_a : list of torch.Tensor
#         First topology. See label_list_to_topology.
#     topology_b : list of torch.Tensor
#         Second topology. See label_list_to_topology.
#     minimum_seqment_overlap : int
#         Minimum overlap between two segments to be considered equal.

#     Returns
#     -------
#     bool
#         True if topologies are equal, False otherwise.
#     """

#     if isinstance(topology_a[0], torch.ByteTensor):
#         topology_a = list([a.cpu().numpy() for a in topology_a])
#     if isinstance(topology_b[0], torch.ByteTensor):
#         topology_b = list([b.cpu().numpy() for b in topology_b])
#     if len(topology_a) != len(topology_b):
#         return False
#     for idx, (_position_a, label_a) in enumerate(topology_a):
#         if label_a != topology_b[idx][1]:
#             if (label_a in (1,2) and topology_b[idx][1] in (1,2)): # assume O == P
#                 print("Continued")
#                 continue
#             else:
#                 print("broke with 1,2")
#                 return False
#         if label_a in (3, 4, 5):
#             overlap_segment_start = max(topology_a[idx][0], topology_b[idx][0])
#             overlap_segment_end = min(topology_a[idx + 1][0], topology_b[idx + 1][0])
#             if label_a == 5:
#                 # Set minimum segment overlap to 3 for Beta regions
#                 minimum_seqment_overlap = 3
#             if overlap_segment_end - overlap_segment_start < minimum_seqment_overlap:
#                 print(overlap_segment_end)
#                 print(overlap_segment_start)

#                 return False
#     return True


type2key = {'I': 0, 'O':1, 'P': 2, 'S': 3, 'M':4, 'B': 5}
#they use: LABELS: Dict[str,int] = {'I': 0, 'O':1, 'P': 2, 'S': 3, 'M':4, 'B': 5}
a = torch.ByteTensor([1,1,1,3,3,3,3,2])
b1 = torch.ByteTensor([1,1,1,1,3,3,3,2])
b2 = torch.ByteTensor([1,1,1,2,2,2,3])

a = label_list_to_topology(a)
b1 = label_list_to_topology(b1)
b2 = label_list_to_topology(b2)
print(a)
print(b1)
print(b2)


def sequence_equality(topology_a,topology_b,minimum_segment_overlap=5):
    if isinstance(topology_a[0], torch.Tensor):
            print("Converted a to list!")
            topology_a = list([a.cpu().numpy() for a in topology_a])
    if isinstance(topology_b[0], torch.Tensor):
            print("Converted b to list!")
            topology_b = list([b.cpu().numpy() for b in topology_b])
    if len(topology_a) != len(topology_b):
            return False
    
    print(topology_a)
    print(topology_b)
    
    for idx, (pos_a,label_a) in enumerate(topology_a):
        print("checking: ",idx)
        pos_b, label_b = topology_b[idx]
        if(label_a!=label_b):
            if(label_a in (1,2) and label_b in (1,2)):
                continue
            else:
                return False
              
        if label_a in (3,4,5):
            overlap_start = max(pos_a,pos_b)
            overlap_end = min(topology_a[idx+1][0],topology_b[idx+1][0])
        
            if label_a == 5:
                minimum_segment_overlap = 3
            if overlap_end-overlap_start < minimum_segment_overlap:
                return False
    return True

        

    # for idx, (position_a,label_a) in enumerate(topology_a):
    #     print("A label: ",label_a)
    #     print("B label: ",topology_b[idx][1])
    #     if label_a!=topology_b[idx][1]:
    #         if (label_a in (4,5) and b[idx][1] in (4,5)): #case O,P: assume O==P
    #             print("Continued")
    #             continue
    #         else:
    #             print("Falsed")
    #             return False
    #     if label_a in (0,2,3): #case: S,M,B
    #         overlap_segment_start = max(topology_a[idx][0],topology_b[idx][0])
    #         print(overlap_segment_start)
    #         overlap_segment_end = min(topology_a[idx+1][0],topology_b[idx+1][0])
    #         print(overlap_segment_end)
    #         overlap = overlap_segment_end-overlap_segment_start
    #         print("Overlap becomes: ",overlap)
    #         #overlap = (b[idx:idx+5]==a[idx:idx+5]).sum()
    #         if(label_a==3):
    #             minimum_overlap = 3
    #         if overlap < minimum_overlap:
    #             return False
    # return True

sequence_equality(a,b1)

sequence_equality(a,b2)

[tensor([0, 1], dtype=torch.uint8), tensor([3, 3]), tensor([7, 2])]
[tensor([0, 1], dtype=torch.uint8), tensor([4, 3]), tensor([7, 2])]
[tensor([0, 1], dtype=torch.uint8), tensor([3, 2]), tensor([6, 3])]
Converted a to list!
Converted b to list!
[array([0, 1], dtype=uint8), array([3, 3]), array([7, 2])]
[array([0, 1], dtype=uint8), array([4, 3]), array([7, 2])]
checking:  0
checking:  1
Converted a to list!
Converted b to list!
[array([0, 1], dtype=uint8), array([3, 3]), array([7, 2])]
[array([0, 1], dtype=uint8), array([3, 2]), array([6, 3])]
checking:  0
checking:  1


False

In [3]:

    

print(label_list_to_topology(a))
a = torch.Tensor([1,1,1,2,3,4,1,1,2])

[tensor([0., 1.]), tensor([3., 2.]), tensor([4., 3.]), tensor([5., 4.]), tensor([6., 1.]), tensor([8., 2.])]
