In [1]:
import torch
from operator import itemgetter

In [2]:
def cyclic_group_generator(vocab_size, group_size, eq_indices):
    """
    Build the generator of a cyclic group acting on token indices.

    The generator maps every token index to itself, except for the equivariant
    tokens: within each list of ``eq_indices`` the token at position ``p`` is
    sent to the token at position ``(p + 1) % group_size``.

    :param vocab_size: size of the vocab; keys ``0 .. vocab_size - 1`` are created
    :param group_size: size of the group
    :param eq_indices: list of lists of token indices for the equivariant words in the vocab;
        each inner list is read at positions ``0 .. group_size - 1``
    :return: dict mapping token index -> token index, plus the extra key
        ``'___size___'`` whose value is ``group_size``
    """
    generator = dict(zip(range(vocab_size), range(vocab_size)))  # start from the identity

    for position in range(group_size):
        successor = (position + 1) % group_size
        for word_cycle in eq_indices:
            # send each equivariant word to the next word of its own cycle
            generator[word_cycle[position]] = word_cycle[successor]

    generator['___size___'] = group_size  # the group size travels with the mapping
    return generator

In [3]:
def cyclic_group(g, vocab_size, group_size):
    """
    Enumerate the elements of a cyclic group from its generator.

    :param g: cyclic group generator (mapping token index -> token index)
    :param vocab_size: number of token indices each group element is defined on
    :param group_size: size of the group
    :return: list of dicts; the first is the identity and every following one is
        the generator applied on top of the previous element
    """
    elements = [dict(zip(range(vocab_size), range(vocab_size)))]  # identity comes first

    for _ in range(group_size - 1):
        previous = elements[-1]
        # one more application of the generator yields the next group element
        elements.append({token: g[previous[token]] for token in range(vocab_size)})

    return elements

In [4]:
def g_transform_data(data, G, device):
    '''
    Apply every group element in G to a tensor of token indices.

    :param data: tensor of token indices (any shape); every value must be a key
        of each group element that is applied
    :param G: list of group elements (dicts mapping token index -> token index);
        G[0] is not applied, the input itself is used in its place
    :param device: device on which the returned tensor is placed
    :return: tensor of shape [len(G), *data.shape] with the dtype of ``data``;
        slice 0 is the input and slice k is G[k] applied token by token
    '''
    data_shape = data.size()
    # reshape (rather than view) also accepts non-contiguous inputs; moving the
    # untransformed slice to `device` keeps every stacked slice on one device
    untransformed_data = data.reshape(-1).to(device)
    tokens = untransformed_data.tolist()
    transformed_data = [untransformed_data]

    for curr_g in G[1:]:
        # a plain per-token lookup always yields a list, so single-token and
        # empty inputs produce a 1-D tensor just like the general case
        mapped_tokens = [curr_g[token] for token in tokens]
        transformed_data.append(torch.tensor(mapped_tokens, dtype=data.dtype, device=device))

    # Tensor.to() is not in-place; everything above is already created on `device`
    return torch.stack(transformed_data).view(len(G), *data_shape)

In [17]:
def g_inv_transform_prob_data(data_list, G):
    '''
    Permute the vocab axis of every group slice with its own group element.

    Note: the group acts on the |V| (last) axis here, not on token ids.
    For each k the result satisfies out[k, :, :, G[k][v]] == data_list[k, :, :, v].
    Every group element and its index list are printed while they are collected.

    :param data_list: tensor of dim [group_size, batch_size, num_tokens, |V|]
    :param G: list of group elements, each a dict with keys 0 .. len(g) - 1
    :return: permuted copy of data_list (the input tensor is not modified)
    '''
    permuted = data_list.clone()  # dim [group_size, batch_size, num_tokens, |V|]

    # turn each dict into the list [g[0], g[1], ...] usable as a tensor index
    vocab_orders = []
    for g in G:
        print(f"g: {g}")
        order = list(map(g.__getitem__, range(len(g))))
        print(f"g_index: {order}")
        vocab_orders.append(order)
        print()

    for k in range(len(data_list)):  # one slice per group element
        # copy the slice first so the scatter below never reads values it has already written
        snapshot = permuted[k, :, :, :].clone()
        permuted[k, :, :, vocab_orders[k]] = snapshot

    return permuted

In [6]:
# demographic: MAN vs WOMAN
# E: [[man, woman], [he, she]]
# Toy setup used by the cells below: a 7-token vocabulary and a two-element group.
# Each inner list of eq_indices holds the token ids of one equivariant word pair,
# in the same order as E above (so presumably man=0, woman=1, he=2, she=3).
vocab_size = 7
group_size = 2
eq_indices = [[0, 1], [2, 3]]

In [7]:
# Build the generator: it swaps the ids inside each equivariant pair and leaves
# every other token id fixed (see the displayed dict).
group_generator = cyclic_group_generator(vocab_size, group_size, eq_indices)
group_generator

{0: 1, 1: 0, 2: 3, 3: 2, 4: 4, 5: 5, 6: 6, '___size___': 2}

In [8]:
# Expand the generator into the full list of group elements; G[0] is the identity.
G = cyclic_group(group_generator, vocab_size, group_size)
G

[{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6},
 {0: 1, 1: 0, 2: 3, 3: 2, 4: 4, 5: 5, 6: 6}]

In [9]:
# Data: two sentences
# sent1: man is smart -> [0, 4, 5]
# sent2: he is good -> [2, 4, 6]

# One row per sentence, one column per token id.
data = torch.tensor([[0, 4, 5], [2, 4, 6]])
data.shape

torch.Size([2, 3])

In [10]:
# Apply every group element to the token ids; the result gains a leading axis of size len(G).
transformed_data = g_transform_data(data, G, device='cpu')
transformed_data

tensor([[[0, 4, 5],
         [2, 4, 6]],

        [[1, 4, 5],
         [3, 4, 6]]])

In [11]:
# Stand-in for a model output: each token id is repeated vocab_size times along a new
# last axis. NOTE: every vocab slot of a token therefore holds the same value, so a
# permutation of the last axis cannot be observed on this tensor.
simulated_model_output = transformed_data.clone().unsqueeze(-1).expand(-1, -1, -1, vocab_size)
simulated_model_output

tensor([[[[0, 0, 0, 0, 0, 0, 0],
          [4, 4, 4, 4, 4, 4, 4],
          [5, 5, 5, 5, 5, 5, 5]],

         [[2, 2, 2, 2, 2, 2, 2],
          [4, 4, 4, 4, 4, 4, 4],
          [6, 6, 6, 6, 6, 6, 6]]],


        [[[1, 1, 1, 1, 1, 1, 1],
          [4, 4, 4, 4, 4, 4, 4],
          [5, 5, 5, 5, 5, 5, 5]],

         [[3, 3, 3, 3, 3, 3, 3],
          [4, 4, 4, 4, 4, 4, 4],
          [6, 6, 6, 6, 6, 6, 6]]]])

In [33]:
# Shape check: the displayed size is [2, 2, 3, 7].
simulated_model_output.shape

torch.Size([2, 2, 3, 7])

In [62]:
# Take vocab slot 0 and flatten to one row per (group element, sentence).
# The literal 3 is the number of tokens per sentence in `data` above.
simulated_model_output[:, :, :, 0].view(-1, 3)

tensor([[0, 4, 5],
        [2, 4, 6],
        [1, 4, 5],
        [3, 4, 6]])

In [18]:
# Permute the vocab axis of each group slice. Because every vocab slot of
# simulated_model_output carries the same value, the displayed result is
# identical to the input whatever permutation is applied.
inverted_output = g_inv_transform_prob_data(simulated_model_output, G)
inverted_output

g: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6}
g_index: [0, 1, 2, 3, 4, 5, 6]

g: {0: 1, 1: 0, 2: 3, 3: 2, 4: 4, 5: 5, 6: 6}
g_index: [1, 0, 3, 2, 4, 5, 6]



tensor([[[[0, 0, 0, 0, 0, 0, 0],
          [4, 4, 4, 4, 4, 4, 4],
          [5, 5, 5, 5, 5, 5, 5]],

         [[2, 2, 2, 2, 2, 2, 2],
          [4, 4, 4, 4, 4, 4, 4],
          [6, 6, 6, 6, 6, 6, 6]]],


        [[[1, 1, 1, 1, 1, 1, 1],
          [4, 4, 4, 4, 4, 4, 4],
          [5, 5, 5, 5, 5, 5, 5]],

         [[3, 3, 3, 3, 3, 3, 3],
          [4, 4, 4, 4, 4, 4, 4],
          [6, 6, 6, 6, 6, 6, 6]]]])

In [13]:
def g_inv_transform_prob_data(data_list, G):
    """
    Return the values stored in vocab slot 0, flattened to one row per sentence.

    NOTE: despite its name, this variant applies no group action; ``G`` is
    accepted only to keep the call signature and is never read.

    :param data_list: tensor of dim [group_size, batch_size, num_tokens, |V|]
    :param G: unused
    :return: new tensor of shape [group_size * batch_size, num_tokens] holding
        ``data_list[:, :, :, 0]``; it shares no memory with ``data_list``
    """
    # the row length is the token axis of the input instead of a hard-coded 3
    num_tokens = data_list.shape[2]

    # select the slot first and clone only that slice: the result stays independent
    # of the input without copying the whole vocab axis
    first_slot = data_list[:, :, :, 0].clone()
    return first_slot.reshape(-1, num_tokens)

In [23]:
def g_inv_transform_prob_data(data_list, G):
    """
    Permute the vocab (last) axis of every group slice with its group element.

    For each group element ``g = G[idx]`` the result satisfies
    ``out[idx, :, :, g[v]] == data_list[idx, :, :, v]`` for every vocab id ``v``.
    The group acts on the vocab axis; the batch and token-position axes are
    left in place.

    :param data_list: tensor of dim [group_size, batch_size, num_tokens, |V|]
    :param G: list of group elements; each maps every vocab id 0 .. |V| - 1
        to a vocab id and must be a permutation of them
    :return: permuted copy of data_list (the input tensor is not modified)
    """
    output_data_list = data_list.clone()  # Clone to avoid modifying the original data
    vocab_size = data_list.shape[-1]

    for idx, g in enumerate(G):
        # destination vocab slot for every source slot v
        targets = [g[v] for v in range(vocab_size)]
        # scatter along the vocab axis; reading from data_list keeps sources pristine
        output_data_list[idx][..., targets] = data_list[idx]

    return output_data_list

In [34]:
def g_inv_transform_prob_data(data_list, G):
    """
    Reorder the vocab (last) axis of every group slice by the inverted group mapping.

    For each group element ``g = G[idx]`` the result satisfies
    ``out[idx, :, :, w] == data_list[idx, :, :, inv_g[w]]`` where ``inv_g`` is
    the inverse of ``g`` on the vocab ids. The batch and token-position axes
    are left in place.

    :param data_list: tensor of dim [group_size, num_sentences, sentence_length, vocab_size]
    :param G: list of group elements; each maps every vocab id 0 .. vocab_size - 1
        to a vocab id
    :return: reordered copy of data_list (the input tensor is not modified)
    :raises KeyError: if a group element is not a permutation of the vocab ids
    """
    # Prepare the output tensor with the same dimensions, independent of the input
    output_data_list = data_list.clone()
    vocab_size = data_list.shape[-1]

    # For each group transformation (there are as many transformations as the size of G)
    for idx, g in enumerate(G):
        # Invert the current group mapping; only real vocab ids are read so that
        # any extra non-token key in g cannot clobber an entry
        inv_g = {g[token_id]: token_id for token_id in range(vocab_size)}

        # Source vocab slot for every output slot
        source_ids = [inv_g[token_id] for token_id in range(vocab_size)]

        # Gather along the vocab axis for all sentences and positions at once
        output_data_list[idx] = data_list[idx][..., source_ids]

    return output_data_list


In [38]:
# Re-display the group elements for reference.
G

[{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6},
 {0: 1, 1: 0, 2: 3, 3: 2, 4: 4, 5: 5, 6: 6}]

In [46]:
def g_inv_transform_prob_data(data_list, G):
    '''
    Permute the vocab axis of every group slice with its own group element.

    Note: Group actions are on batch_size x |V|, instead of batch_size x 1.
    For each i the result satisfies
    out[i, :, :, G[i][v]] == data_list[i, :, :, v] for every vocab id v.

    :param data_list: tensor of dim [group_size, batch_size, num_tokens, |V|]
    :param G: list of group elements, each a dict with keys 0 .. len(g) - 1
        that is a permutation of the vocab ids
    :return: permuted copy of data_list (the input tensor is not modified)
    '''
    output_data_list = data_list.clone()  # dim [group_size, batch_size, num_tokens, |V|]

    for i in range(len(data_list)):  # iterate over group size
        g = G[i]
        # destination vocab slot for every source slot
        g_index = [g[v] for v in range(len(g))]
        # scatter along the last axis only; sources are read from the untouched
        # input, so no second clone of the slice is needed
        output_data_list[i][..., g_index] = data_list[i]

    return output_data_list