In [1]:
import random, os
import numpy as np
import torch

random_seed = 42 

def seed_everything(seed: int):
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
seed_everything(random_seed)

In [2]:
# Deactivate activation values
class Deactivator:
    def __init__(self, target_mlp_layer, selected_neurons):
        self.target_mlp_layer = target_mlp_layer
        self.selected_neurons = selected_neurons
        self.deactivation_hook = target_mlp_layer.register_forward_hook(self.deactivate)
    
    def deactivate(self, module, input, output):
        output[:, :, self.selected_neurons] *= 0
        return output
    
    def remove_hook(self):
        self.deactivation_hook.remove()


In [3]:
from transformers import AutoModel, AutoTokenizer

In [4]:

import torch

def find_last_token_pos(mask):
    reversed_mask = mask.flip(1)
    last_indices = mask.size(1) - 1 - reversed_mask.argmax(dim=1)
    last_indices[reversed_mask.sum(1) == 0] = 0
    return last_indices

def extract_last_token_repr(hidden_state, last_token_indices):
    batch_indices = torch.arange(len(hidden_state))
    return hidden_state[batch_indices, last_token_indices].unsqueeze(1)

In [5]:
import torch.nn.functional as F
from tqdm import tqdm
import einops
import numpy as np

def collect_semantic_sim(model, tokenizer, inputs_1, inputs_2, batch_size, patched_neurons_per_layer=None):
    all_cos_sims = []
    model.eval()
    with torch.no_grad():
        for batch_start in tqdm(range(0, len(inputs_1), batch_size)):
            text_batch_1 = inputs_1[batch_start:min(batch_start+batch_size,len(inputs_1))]
            tokenized_batch_1 = tokenizer(text_batch_1, padding=True, return_tensors='pt')
            attention_mask_1 = tokenized_batch_1['attention_mask']
            last_token_pos_1 = find_last_token_pos(attention_mask_1)
            input_ids_1 = tokenized_batch_1['input_ids'].to(model.device)
            attention_mask_1 = attention_mask_1.to(model.device)
            last_hidden_state_1 = model(input_ids=input_ids_1, attention_mask=attention_mask_1, output_hidden_states=True).hidden_states[-1].detach().cpu()
            last_token_representation_1  = extract_last_token_repr(last_hidden_state_1, last_token_pos_1)

            deactivators = []
            if patched_neurons_per_layer is not None:
                if 'LlamaForCausalLM' in str(type(model)) or 'CohereForCausalLM' in str(type(model)) or 'Qwen2ForCausalLM' in str(type(model)):
                    deactivators = [Deactivator(layer.mlp.act_fn, patched_neurons_per_layer[idx]) for idx, layer in enumerate(model.model.layers) if idx in patched_neurons_per_layer]
                elif 'BloomForCausalLM' in str(type(model)):
                    deactivators = [Deactivator(layer.mlp.gelu_impl, patched_neurons_per_layer[idx]) for idx, layer in enumerate(model.transformer.h) if idx in patched_neurons_per_layer]
                else:
                    print('model is not supported!')
                    return []
            text_batch_2 = inputs_2[batch_start:min(batch_start+batch_size,len(inputs_1))]
            tokenized_batch_2 = tokenizer(text_batch_2, padding=True, return_tensors='pt')
            attention_mask_2 = tokenized_batch_2['attention_mask']
            last_token_pos_2 = find_last_token_pos(attention_mask_2)
            input_ids_2 = tokenized_batch_2['input_ids'].to(model.device)
            attention_mask_2 = attention_mask_2.to(model.device)
            last_hidden_state_2 = model(input_ids=input_ids_2, attention_mask=attention_mask_2, output_hidden_states=True).hidden_states[-1].detach().cpu()

            for deactivator in deactivators:
                deactivator.remove_hook()

            last_token_representation_2  = extract_last_token_repr(last_hidden_state_2, last_token_pos_2)
            dot_prod = einops.einsum(last_token_representation_2, last_token_representation_1, 'b t d, b t d -> b').detach().cpu()

            norm_1 = torch.norm(last_token_representation_1, dim=[-2,-1]).detach().cpu()
            norm_2 = torch.norm(last_token_representation_1, dim=[-2,-1]).detach().cpu()
            eps = 1e-8
            normalizer = norm_1*norm_2
            cos_sim = torch.divide(dot_prod, normalizer)
            cos_sim_lst = [float(val.item()) for val in cos_sim]
            all_cos_sims.extend(cos_sim_lst)
    torch.cuda.empty_cache()
    return all_cos_sims

In [6]:
import einops

In [7]:
import pickle

neuron_file = 'natural_codemixed_sensitive_neurons.pkl'

with open(neuron_file, 'rb') as f:
    neuron_info = pickle.load(f)

In [9]:
import random

def generate_random_patched_neurons(acts, model_name, out_file, k=2500):
    random_acts = acts[model_name].copy()[::-1]
    random.shuffle(random_acts)
    selected_neurons = random.sample(random_acts, k)
    selected_neurons_per_layer = dict()
    for neuron_info in selected_neurons:
        layer = neuron_info['layer']
        neuron_pos = neuron_info['neuron_pos']
        if layer not in selected_neurons_per_layer:
            selected_neurons_per_layer[layer] = []
        selected_neurons_per_layer[layer].append(neuron_pos)
    
    for layer in selected_neurons_per_layer.keys():
        selected_neurons_per_layer[layer] = sorted(selected_neurons_per_layer[layer])
    
    with open(out_file, 'wb') as f:
        pickle.dump(selected_neurons_per_layer, f)
    

In [25]:
generate_random_patched_neurons(cm_acts, 'aya-expanse-8b', 'random-neurons-aya-5.pkl')

In [32]:
generate_random_patched_neurons(cm_acts, 'bloom-7b1', 'random-neurons-bloom-5.pkl')

In [37]:
generate_random_patched_neurons(cm_acts, 'Llama-3.2-3B', 'random-neurons-llama3-5.pkl')

In [42]:
generate_random_patched_neurons(cm_acts, 'Qwen2.5-7B', 'random-neurons-qwen-5.pkl')

In [8]:
def pick_topk_neurons(all_neuron_acts, model_name, top_k=2500):
    all_topk_neuron_infos = [neuron_acts[model_name][:top_k] for neuron_acts in all_neuron_acts]
    selected_neurons_per_layer = dict()
    selected_neurons_name = []
    for i in range(top_k):
        for neuron_acts in all_topk_neuron_infos:
            neuron_info = neuron_acts[i]
            neuron_name = neuron_info['name']
            layer = neuron_info['layer']
            neuron_pos = neuron_info['neuron_pos']
            if neuron_name not in selected_neurons_name:
                selected_neurons_name.append(neuron_name)
                if layer not in selected_neurons_per_layer:
                    selected_neurons_per_layer[layer] = []
                selected_neurons_per_layer[layer].append(neuron_pos)
            if len(selected_neurons_name)==top_k:
                break
        if len(selected_neurons_name)==top_k:
            break
    
    for layer in selected_neurons_per_layer.keys():
        selected_neurons_per_layer[layer] = sorted(selected_neurons_per_layer[layer])
    return selected_neurons_per_layer



In [9]:
def load_from_pkl(filepath):
    with open(filepath, 'rb') as f:
        obj = pickle.load(f)
    return obj

In [10]:
def save_to_pkl(data, filepath):
    with open(filepath, 'wb') as f:
        pickle.dump(data, f)
    

In [11]:
cm_acts = load_from_pkl('natural_codemixed_sensitive_neurons.pkl')
synth_0_25_cm_acts = load_from_pkl('synth_0.25_codemixed_sensitive_neurons.pkl')
synth_0_5_cm_acts = load_from_pkl('synth_0.5_codemixed_sensitive_neurons.pkl')
synth_0_75_cm_acts = load_from_pkl('synth_0.75_codemixed_sensitive_neurons.pkl')
en_acts = load_from_pkl('english_sensitive_neurons.pkl')
hi_acts = load_from_pkl('hindi_sensitive_neurons.pkl')


# aya_cm_en_hi_acts = pick_topk_neurons([cm_acts, en_acts, hi_acts], 'aya-expanse-8b')
# aya_en_hi_acts = pick_topk_neurons([en_acts, hi_acts], 'aya-expanse-8b')
# aya_cm_acts = pick_topk_neurons([cm_acts], 'aya-expanse-8b')

In [12]:
ar_acts = load_from_pkl('arabic_sensitive_neurons.pkl')

In [15]:
cm_types = ['natural', 'synth_0.25', 'synth_0.5', 'synth_0.75']
mono_types = ['en-hi', 'en-ar']
for model_name in synth_0_25_cm_acts.keys():
    for cm_type, current_cm_acts in zip(cm_types, [cm_acts, synth_0_25_cm_acts, synth_0_5_cm_acts, synth_0_75_cm_acts]):
        model_cm_acts = pick_topk_neurons([current_cm_acts], model_name)
        save_to_pkl(model_cm_acts, f"{model_name}_{cm_type}_codemixing.pkl")
        for mono_type, current_mono_acts in zip(mono_types, [[en_acts, hi_acts], [en_acts, ar_acts]]):
            combined_acts = [current_cm_acts] + current_mono_acts
            model_cm_mono_acts = pick_topk_neurons(combined_acts, model_name)
            model_mono_acts = pick_topk_neurons(current_mono_acts, model_name)
            save_to_pkl(model_cm_mono_acts, f"{model_name}_{cm_type}_codemixing_lang_{mono_type}.pkl")
            save_to_pkl(model_mono_acts, f"{model_name}_lang_{mono_type}.pkl")
        
        



In [None]:
dev='datasets/ArzEn-ST_arabic-english-arablish/data/dev.txt'
test='datasets/ArzEn-ST_arabic-english-arablish/data/test.txt'

dev_cm = []
dev_en = []
dev_ar = []
with open(dev, 'r') as f:
    lines = f.readlines()
    for line in lines:
        cm, en, ar = line.split('\t')
        dev_cm.append(cm)
        dev_en.append(en)
        dev_ar.append(ar)
    
dev_en = ''


In [11]:
aya_acts

{31: [37,
  76,
  81,
  84,
  128,
  130,
  167,
  170,
  180,
  184,
  207,
  250,
  361,
  364,
  371,
  373,
  408,
  427,
  437,
  439,
  482,
  506,
  519,
  528,
  531,
  544,
  559,
  586,
  597,
  625,
  646,
  652,
  669,
  699,
  710,
  716,
  765,
  785,
  789,
  834,
  845,
  879,
  886,
  893,
  905,
  909,
  931,
  948,
  962,
  987,
  991,
  1011,
  1033,
  1043,
  1070,
  1091,
  1099,
  1107,
  1109,
  1148,
  1161,
  1209,
  1231,
  1258,
  1261,
  1269,
  1272,
  1299,
  1326,
  1328,
  1330,
  1395,
  1401,
  1410,
  1423,
  1433,
  1434,
  1447,
  1476,
  1504,
  1518,
  1522,
  1523,
  1526,
  1528,
  1543,
  1567,
  1587,
  1602,
  1607,
  1615,
  1617,
  1656,
  1790,
  1833,
  1846,
  1864,
  1865,
  1866,
  1871,
  1874,
  1876,
  1885,
  1888,
  1903,
  1909,
  1945,
  1979,
  1980,
  1999,
  2005,
  2021,
  2027,
  2029,
  2032,
  2033,
  2034,
  2080,
  2081,
  2127,
  2170,
  2171,
  2178,
  2193,
  2198,
  2204,
  2208,
  2212,
  2215,
  2217,
  2244,
  2

In [13]:
aya_rand_acts = load_from_pkl('random-neurons-aya-1.pkl')

In [14]:
aya_rand_acts_2 = load_from_pkl('random-neurons-aya-2.pkl')
aya_rand_acts_3 = load_from_pkl('random-neurons-aya-3.pkl')
aya_rand_acts_4 = load_from_pkl('random-neurons-aya-4.pkl')
aya_rand_acts_5 = load_from_pkl('random-neurons-aya-5.pkl')

In [12]:
aya_rand_acts

{25: [303,
  409,
  599,
  758,
  1651,
  1716,
  1957,
  2008,
  2254,
  2273,
  2429,
  3431,
  3634,
  4309,
  4965,
  5109,
  5250,
  5321,
  5334,
  5341,
  5530,
  5764,
  5852,
  6325,
  6633,
  6917,
  6991,
  7028,
  7390,
  7769,
  7810,
  7918,
  7922,
  7977,
  8083,
  8321,
  8345,
  8366,
  8420,
  8594,
  8708,
  8918,
  9370,
  9481,
  9560,
  9563,
  9651,
  9808,
  10132,
  10722,
  10748,
  11136,
  11268,
  11281,
  11291,
  11297,
  11523,
  12371,
  12451,
  12546,
  12699,
  13079,
  13215,
  13269,
  13311,
  13407,
  13440,
  14142,
  14313],
 31: [104,
  383,
  480,
  884,
  950,
  953,
  1421,
  1625,
  1778,
  1967,
  2507,
  2695,
  2700,
  2774,
  2787,
  2937,
  2951,
  2991,
  3113,
  3139,
  3262,
  3264,
  3560,
  3577,
  3593,
  3664,
  3671,
  3994,
  4052,
  4759,
  4815,
  5037,
  5186,
  5272,
  5325,
  5725,
  6091,
  6133,
  6418,
  6921,
  7098,
  7209,
  7689,
  7717,
  7726,
  7862,
  7958,
  8004,
  8198,
  8526,
  8529,
  8556,
  8589,
  87

In [14]:
cm_acts

{'bloom-7b1': [{'name': 'L7N10392',
   'layer': 7,
   'neuron_pos': 10392,
   'act_diff': 1.4453125},
  {'name': 'L7N10589', 'layer': 7, 'neuron_pos': 10589, 'act_diff': 1.4375},
  {'name': 'L7N1494', 'layer': 7, 'neuron_pos': 1494, 'act_diff': 1.4375},
  {'name': 'L7N9154', 'layer': 7, 'neuron_pos': 9154, 'act_diff': 1.4375},
  {'name': 'L7N13654', 'layer': 7, 'neuron_pos': 13654, 'act_diff': 1.4375},
  {'name': 'L7N3415', 'layer': 7, 'neuron_pos': 3415, 'act_diff': 1.4375},
  {'name': 'L7N15126', 'layer': 7, 'neuron_pos': 15126, 'act_diff': 1.4375},
  {'name': 'L7N11942', 'layer': 7, 'neuron_pos': 11942, 'act_diff': 1.34375},
  {'name': 'L7N12885',
   'layer': 7,
   'neuron_pos': 12885,
   'act_diff': 0.80859375},
  {'name': 'L7N4188', 'layer': 7, 'neuron_pos': 4188, 'act_diff': 0.78125},
  {'name': 'L26N8694', 'layer': 26, 'neuron_pos': 8694, 'act_diff': 0.7734375},
  {'name': 'L25N5749',
   'layer': 25,
   'neuron_pos': 5749,
   'act_diff': 0.73046875},
  {'name': 'L25N14719',
   '

In [15]:
import pickle

test = pickle.load(open('natural_codemixed_sensitive_neurons.pkl', 'rb'))
test['aya-expanse-8b']

[{'name': 'L31N10126', 'layer': 31, 'neuron_pos': 10126, 'act_diff': 0.359375},
 {'name': 'L31N12469',
  'layer': 31,
  'neuron_pos': 12469,
  'act_diff': 0.357421875},
 {'name': 'L31N5388', 'layer': 31, 'neuron_pos': 5388, 'act_diff': 0.31640625},
 {'name': 'L30N7357', 'layer': 30, 'neuron_pos': 7357, 'act_diff': 0.29296875},
 {'name': 'L31N13289',
  'layer': 31,
  'neuron_pos': 13289,
  'act_diff': 0.28515625},
 {'name': 'L30N6571', 'layer': 30, 'neuron_pos': 6571, 'act_diff': 0.28125},
 {'name': 'L28N11619',
  'layer': 28,
  'neuron_pos': 11619,
  'act_diff': 0.267578125},
 {'name': 'L29N5667',
  'layer': 29,
  'neuron_pos': 5667,
  'act_diff': 0.263671875},
 {'name': 'L28N13660',
  'layer': 28,
  'neuron_pos': 13660,
  'act_diff': 0.26171875},
 {'name': 'L30N4935',
  'layer': 30,
  'neuron_pos': 4935,
  'act_diff': 0.251953125},
 {'name': 'L30N10487', 'layer': 30, 'neuron_pos': 10487, 'act_diff': 0.25},
 {'name': 'L21N819', 'layer': 21, 'neuron_pos': 819, 'act_diff': 0.248046875},


In [13]:
import torch 
import torch.nn.functional as F

input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)
output = F.cosine_similarity(input1, input2)
print(output)



tensor([-0.1000, -0.0861, -0.0484, -0.1040, -0.1087, -0.0555,  0.0200,  0.0161,
         0.0433,  0.0505, -0.0316, -0.1503, -0.1502,  0.1159,  0.0196,  0.0435,
        -0.1118,  0.0549, -0.1021,  0.1683, -0.0976, -0.1095, -0.0414, -0.0363,
        -0.0652,  0.0498, -0.0747, -0.0784,  0.1673,  0.0996,  0.0398, -0.0192,
        -0.0586, -0.0555,  0.1133, -0.0736, -0.0192,  0.1258, -0.0020,  0.0339,
         0.0877, -0.0818, -0.1489,  0.0351, -0.0439, -0.0316,  0.0475, -0.2202,
         0.0661, -0.1296, -0.1391, -0.1539,  0.0282,  0.0441, -0.1271,  0.1069,
         0.1101, -0.1763,  0.0515,  0.0713,  0.0384, -0.0307,  0.0833,  0.0827,
        -0.1422,  0.0282, -0.1676,  0.0225, -0.1406, -0.0573, -0.0441, -0.0184,
        -0.0460, -0.0421,  0.0240, -0.0298,  0.1421, -0.0591,  0.1253, -0.0393,
        -0.0845, -0.0141,  0.0429, -0.1959, -0.0536,  0.0082, -0.0592, -0.0091,
        -0.0010,  0.1158,  0.0450,  0.0077,  0.0735,  0.0038, -0.1325, -0.0625,
        -0.0367,  0.0282, -0.0558, -0.09

In [12]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Thu Nov 28 19:02:02 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:41:00.0 Off |                  Off |
| 30%   36C    P8              28W / 450W |   2664MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [12]:
!nvidia-smi

Thu Nov 28 18:53:44 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


|   0  NVIDIA GeForce RTX 4090        On  | 00000000:41:00.0 Off |                  Off |
|  0%   45C    P8              28W / 450W |  13484MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|    0   N/A  N/A   2828295      G   /usr/lib/xorg/Xorg                           44MiB |
|    0   N/A  N/A   4156556      C   ....conda/envs/cmix_neurons/bin/python    13424MiB |
+---------

In [16]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = 'CohereForAI/aya-expanse-8b'
access_token = "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"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", token=access_token, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [17]:
cm_file = 'datasets/calcs_english-hinglish/dev_natural_cmix.txt'
mono_file_1 = 'datasets/calcs_english-hinglish/dev_en.txt'

cm_inputs = []
mono_inputs = []

with open(cm_file, 'r') as f:
    lines = f.readlines()
    for line in lines:
        cm_inputs.append(line.lower().strip())

with open(mono_file_1, 'r') as f:
    lines = f.readlines()
    for line in lines:
        mono_inputs.append(line.lower().strip())



In [18]:
batch_size = 4
aya_natural_codemixed_semantic_clean = collect_semantic_sim(model, tokenizer, mono_inputs, cm_inputs, batch_size)
aya_natural_codemixed_semantic_cm_en_hi_deact = collect_semantic_sim(model, tokenizer, mono_inputs, cm_inputs, batch_size, aya_cm_en_hi_acts)
aya_natural_codemixed_semantic_cm_deact = collect_semantic_sim(model, tokenizer, mono_inputs, cm_inputs, batch_size, aya_cm_acts)
aya_natural_codemixed_semantic_en_hi_deact = collect_semantic_sim(model, tokenizer, mono_inputs, cm_inputs, batch_size, aya_en_hi_acts)
aya_natural_codemixed_semantic_rand_1_deact = collect_semantic_sim(model, tokenizer, mono_inputs, cm_inputs, batch_size, aya_rand_acts)
aya_natural_codemixed_semantic_rand_2_deact = collect_semantic_sim(model, tokenizer, mono_inputs, cm_inputs, batch_size, aya_rand_acts_2)
aya_natural_codemixed_semantic_rand_3_deact = collect_semantic_sim(model, tokenizer, mono_inputs, cm_inputs, batch_size, aya_rand_acts_3)
aya_natural_codemixed_semantic_rand_4_deact = collect_semantic_sim(model, tokenizer, mono_inputs, cm_inputs, batch_size, aya_rand_acts_4)
aya_natural_codemixed_semantic_rand_5_deact = collect_semantic_sim(model, tokenizer, mono_inputs, cm_inputs, batch_size, aya_rand_acts_5)

100%|██████████| 236/236 [02:51<00:00,  1.38it/s]
100%|██████████| 236/236 [02:54<00:00,  1.35it/s]
100%|██████████| 236/236 [02:54<00:00,  1.35it/s]
100%|██████████| 236/236 [02:54<00:00,  1.35it/s]
100%|██████████| 236/236 [02:54<00:00,  1.35it/s]
100%|██████████| 236/236 [02:54<00:00,  1.35it/s]
100%|██████████| 236/236 [02:54<00:00,  1.35it/s]
100%|██████████| 236/236 [02:54<00:00,  1.35it/s]
100%|██████████| 236/236 [02:54<00:00,  1.35it/s]


In [19]:
sum(aya_natural_codemixed_semantic_clean)/len(aya_natural_codemixed_semantic_clean)

0.48615292429670914

In [20]:
sum(aya_natural_codemixed_semantic_cm_en_hi_deact)/len(aya_natural_codemixed_semantic_cm_en_hi_deact)

0.4210326889264862

In [21]:
sum(aya_natural_codemixed_semantic_cm_deact)/len(aya_natural_codemixed_semantic_cm_deact)

0.44179138054007433

In [22]:
sum(aya_natural_codemixed_semantic_en_hi_deact)/len(aya_natural_codemixed_semantic_en_hi_deact)

0.3606444275809448

In [23]:
sum(aya_natural_codemixed_semantic_rand_1_deact)/len(aya_natural_codemixed_semantic_rand_1_deact)

0.4818364040107484

In [24]:
sum(aya_natural_codemixed_semantic_rand_2_deact)/len(aya_natural_codemixed_semantic_rand_2_deact)

0.4917129548998142

In [25]:
sum(aya_natural_codemixed_semantic_rand_3_deact)/len(aya_natural_codemixed_semantic_rand_3_deact)

0.48457534044917727

In [26]:
sum(aya_natural_codemixed_semantic_rand_4_deact)/len(aya_natural_codemixed_semantic_rand_4_deact)

0.48472591975185775

In [27]:
sum(aya_natural_codemixed_semantic_rand_5_deact)/len(aya_natural_codemixed_semantic_rand_5_deact)

0.4871522939888535

In [22]:
sum(aya_natural_codemixed_semantic_rand_deact_5)/len(aya_natural_codemixed_semantic_rand_deact_5)

0.04883227176220807

In [None]:
aya_natural_codemixed_semantic_deact = collect_semantic_sim(model, tokenizer, mono_inputs, cm_inputs, batch_size, aya_acts)


In [19]:
aya_natural_codemixed_semantic_deact

[1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 1.0

In [4]:
from IPython.core.debugger import Pdb

def hello_world():
    a = 3
    b = 2
    print(a+b)


In [5]:
hello_world()

3
3
3
3
3


KeyboardInterrupt: 

In [15]:
model

CohereForCausalLM(
  (model): CohereModel(
    (embed_tokens): Embedding(256000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x CohereDecoderLayer(
        (self_attn): CohereSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): CohereRotaryEmbedding()
        )
        (mlp): CohereMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): CohereLayerNorm()
      )
    )
    (norm): CohereLayerNorm()
    (rotary_emb): CohereRotaryEmbedding

In [20]:
with open('aya_natural_codemixed_semantic_clean.pkl', 'rb') as f:
    baseline = pickle.load(f)

In [21]:
baseline

[1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 1.0,
 1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 1.0,
 -1.0,
 -1.0,
 1.0,
 

In [23]:
sum(aya_natural_codemixed_semantic_deact)/len(aya_natural_codemixed_semantic_deact)

0.027600849256900213

In [24]:
sum(baseline)/len(baseline)

0.059447983014861996

In [14]:
#with open('aya_natural_codemixed_semantic_clean.pkl', 'wb') as f:
#    pickle.dump(aya_natural_codemixed_semantic_clean, f)

with open('aya_natural_codemixed_semantic_deact.pkl', 'wb') as f:
    pickle.dump(aya_natural_codemixed_semantic_deact, f)

In [15]:
with open('aya_natural_codemixed_semantic_deact_cm_en_hi.pkl', 'wb') as f:
    pickle.dump(aya_natural_codemixed_semantic_deact, f)

In [1]:
!nvidia-smi

Thu Nov 28 18:47:42 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:41:00.0 Off |                  Off |
|  0%   44C    P8              31W / 450W |     54MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [10]:
all_cos_sims = collect_semantic_sim(model, tokenizer, cm_inputs, mono_inputs, 16)

  return torch.tensor(last_indices).unsqueeze(-1)


tensor([[0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])


  5%|▌         | 3/59 [00:00<00:09,  5.85it/s]

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 

 12%|█▏        | 7/59 [00:00<00:04, 11.39it/s]

tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

 20%|██        | 12/59 [00:01<00:02, 16.26it/s]

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

 31%|███       | 18/59 [00:01<00:02, 19.08it/s]

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
      

 36%|███▌      | 21/59 [00:01<00:02, 17.15it/s]

tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0

 41%|████      | 24/59 [00:01<00:01, 17.83it/s]

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 

 51%|█████     | 30/59 [00:01<00:01, 19.42it/s]

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         

 61%|██████    | 36/59 [00:02<00:01, 20.62it/s]

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

 66%|██████▌   | 39/59 [00:02<00:00, 20.75it/s]

tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

 71%|███████   | 42/59 [00:02<00:00, 17.69it/s]

tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

 80%|███████▉  | 47/59 [00:02<00:00, 17.61it/s]

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         1],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 

 85%|████████▍ | 50/59 [00:03<00:00, 18.14it/s]

tensor([[0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])
tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1

 93%|█████████▎| 55/59 [00:03<00:00, 18.87it/s]

tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
         1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
         1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
         1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
         1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],


100%|██████████| 59/59 [00:03<00:00, 16.85it/s]


tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1],
        [0, 0, 0, 0, 0, 0, 0, 

In [11]:
len(all_cos_sims)

942

In [11]:
all_cos_sims

[tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(-1.),
 tensor(-1.),
 tensor(1.),
 tensor(-1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(-1.),
 tensor(-1.),
 tensor(1.),
 tensor(-1.),
 tensor(-1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(1.),
 tensor(-1.),


In [35]:
del model, tokenizer
torch.cuda.empty_cache()


In [None]:
all_cos_sims = ollect_semantic_sim(model, tokenizer, inputs_1, inputs_2, batch_size)

In [13]:
masks = [
        torch.tensor([
            [1, 1, 1, 0, 0],
            [1, 1, 0, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 1, 1, 0, 1],
            [0, 0, 0, 0, 0]
        ]),
        torch.tensor([
            [1, 1, 1, 1, 1],
            [1, 1, 1, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0]
        ])
    ]
print("\nMethod 1 (Argmax of Reversed):")
print(method_argmax(masks[0]))


Method 1 (Argmax of Reversed):
tensor([2, 1, 0, 4, 4])
tensor([2, 1, 0, 4, 0])


In [3]:
import torch

batch_size = 4
seq_len = 10
dim = 5
x = torch.randn(batch_size, seq_len, dim)

positions = torch.tensor([2, 5, 7, 3]).unsqueeze(1) # batch*1*dim
selected_2 = torch.gather(x, 1, positions.unsqueeze(2).expand(-1, -1, dim))
print("\nMethod 2 (Using torch.gather):")
print(selected_2.shape)  # Should be (batch_size, 1, dim)
print(selected_2.squeeze(1))


Method 2 (Using torch.gather):
torch.Size([4, 1, 5])
tensor([[ 1.6147, -0.0753, -0.8926,  0.1688, -1.3310],
        [-0.9422,  0.9169, -0.8206, -1.0130,  1.5415],
        [ 1.1252, -0.4453,  0.0153,  0.9160,  0.3917],
        [-0.1650,  0.9166, -0.4730, -0.8105, -0.8646]])
