In [None]:
! pip install datasets transformers

Collecting datasets
  Downloading datasets-1.11.0-py3-none-any.whl (264 kB)
[K     |████████████████████████████████| 264 kB 14.4 MB/s 
[?25hCollecting transformers
  Downloading transformers-4.9.2-py3-none-any.whl (2.6 MB)
[K     |████████████████████████████████| 2.6 MB 73.6 MB/s 
Collecting tqdm>=4.42
  Downloading tqdm-4.62.0-py2.py3-none-any.whl (76 kB)
[K     |████████████████████████████████| 76 kB 5.3 MB/s 
Collecting huggingface-hub<0.1.0
  Downloading huggingface_hub-0.0.15-py3-none-any.whl (43 kB)
[K     |████████████████████████████████| 43 kB 1.9 MB/s 
[?25hCollecting fsspec>=2021.05.0
  Downloading fsspec-2021.7.0-py3-none-any.whl (118 kB)
[K     |████████████████████████████████| 118 kB 84.3 MB/s 
[?25hCollecting xxhash
  Downloading xxhash-2.0.2-cp37-cp37m-manylinux2010_x86_64.whl (243 kB)
[K     |████████████████████████████████| 243 kB 65.1 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |█████████████

In [None]:
from transformers import AutoTokenizer, BertForMaskedLM
from torch.nn import functional as F
import torch
import numpy as np
from numpy import savetxt, loadtxt
import pandas as pd
import random
from datetime import datetime
import errno
import os
from google.colab import files

# Experiment Settings

In [None]:
model_name = 'xlnet-large-cased'  # 'bert-large-uncased', 'gpt2', 'xlm-roberta-large', 'xlnet-large-cased', ('roberta-large')

# Priming: a random prime input uses num_prime_tokens distinct A tokens and
# num_prime_tokens distinct B tokens, i.e. num_prime_tokens**2 unique (A, B) pairs.
num_prime_tokens = 2
occur_prime_trigrams = 4 # Each unique priming trigram occurs x times in prime input
prime_pause_str = '.' # (Punctuation) token (string) that separates trigrams in prime input

priming_AAB = True
priming_ABA = True
priming_ABB = True

# Probing: num_probe_tokens A and B tokens -> num_probe_tokens**2 probe trigrams.
num_probe_tokens = 4
# This version works only if all probing patterns are True
probing_AAB = True
probing_ABA = True
probing_ABB = True
probing_ABCa = True
probing_ABCb = True

# Not implemented - needs to be set to True
real_probes = True

experiment_cycles = 16*16  # number of independent prime/probe draws

custom_primes = False # Number of trigrams = num_prime_tokens**2 * occur_prime_trigrams
custom_probes = False # Number of trigrams = num_probe_tokens**2 
probes_eql_primes = False  # probe with the prime trigrams themselves
upload_required = True # Set to False if PPMI index is already available

# Optional seed for reproducible token/position sampling. None keeps the
# unseeded behaviour (a different random draw on every run).
random_seed = None
if random_seed is not None:
  random.seed(random_seed)
  np.random.seed(random_seed)
  torch.manual_seed(random_seed)

num_probe = num_probe_tokens**2  # probe trigrams per cycle

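# Custom Primes or Probes: Upload PMI Index

Only needed when `custom_primes` or `custom_probes` is enabled. The uploaded CSV needs an index column, a `Pattern` column (containing `AAB`, `ABA` or `ABB`) and token-id columns `Position 1` to `Position 3`.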

In [None]:
if upload_required and (custom_primes or custom_probes):
  # Shuffle the 2*num_probe candidate row positions used per pattern of the PMI
  # index and split them into disjoint halves: the first num_probe go to the
  # probes, the remaining num_probe to the primes.
  prime_positions = np.random.permutation(np.arange(num_probe*2))
  probe_positions = prime_positions[:num_probe]
  prime_positions = prime_positions[num_probe:]
  print('Please upload PMI CSV file for ', model_name)
  uploaded_files = files.upload()  # {filename: contents} of the uploaded files
  if not uploaded_files:
    # Previously surfaced as an unexplained IndexError.
    raise RuntimeError('No PMI CSV file was uploaded for ' + model_name)
  # Only the first uploaded file is used.
  custom_CSV = next(iter(uploaded_files))
  pmi_index = pd.read_csv(custom_CSV, index_col=0)
elif (custom_primes or custom_probes) and not all(
    name in globals() for name in ('pmi_index', 'prime_positions', 'probe_positions')):
  # Fail here instead of with a NameError deep inside the experiment loop.
  raise RuntimeError('custom_primes/custom_probes need a PMI index: set upload_required = True '
                     'or load pmi_index, prime_positions and probe_positions first')

# Run Experiment

In [None]:
metainfo = ''  # per-cycle log of primes/probes; saved to <timestamp>_metainfo.txt

# ABCa and ABCb are alternative constructions of the ABC probe (both write
# probe_ABC, so only one can be used); when both are requested ABCa is kept.
if probing_ABCa and probing_ABCb:
  probing_ABCb = False

tokenizer = AutoTokenizer.from_pretrained(model_name)

model_name_lc = model_name.lower()
# Defined for every model (previously only set in the XLNet branch).
xlnet = 'xlnet' in model_name_lc

# special_tokens = [sequence start, mask, pause, sequence end] token ids.
if 'gpt2' in model_name_lc:
  from transformers import GPT2LMHeadModel
  model = GPT2LMHeadModel.from_pretrained(model_name)
  special_tokens = np.array([None, 50256, 
                           tokenizer.convert_tokens_to_ids(prime_pause_str), None])# CLS (None), MASK (<|endoftext|>), pause, SEP (None)
elif xlnet:
  from transformers import XLNetLMHeadModel
  model = XLNetLMHeadModel.from_pretrained(model_name)
  special_tokens = np.array([17, tokenizer.mask_token_id, 
                           tokenizer.convert_tokens_to_ids(prime_pause_str), tokenizer.cls_token_id])# 17 (Start), MASK, pause, CLS
else:
  # Masked-LM checkpoints (BERT, RoBERTa, XLM-R, ...). AutoModelForMaskedLM picks
  # the class matching the checkpoint's config (BertForMaskedLM for BERT);
  # hard-coding BertForMaskedLM does not load RoBERTa / XLM-R weights.
  from transformers import AutoModelForMaskedLM
  model = AutoModelForMaskedLM.from_pretrained(model_name)
  special_tokens = np.array([tokenizer.cls_token_id, tokenizer.mask_token_id, 
                           tokenizer.convert_tokens_to_ids(prime_pause_str), tokenizer.sep_token_id])# CLS, MASK, pause, SEP

# One pause token per trigram of a random prime input
# (num_prime_tokens**2 pairs, each repeated occur_prime_trigrams times).
pauses = np.repeat(special_tokens[2], num_prime_tokens**2*occur_prime_trigrams)

def token_selector(num_tkns, ids_xcl=(), vocab_size=None):
  """Draw `num_tkns` distinct random token ids that are not in `ids_xcl`.

  Args:
    num_tkns: number of ids to draw (without replacement).
    ids_xcl: iterable of ids that must not be drawn (e.g. special tokens or
      tokens already used in this cycle). None entries are ignored.
    vocab_size: ids are drawn from range(vocab_size); defaults to the size of
      the global tokenizer's vocabulary.

  Returns:
    A list of `num_tkns` distinct ints.

  Raises:
    ValueError: if fewer than `num_tkns` ids remain after the exclusion.
  """
  if vocab_size is None:
    vocab_size = len(tokenizer.vocab)
  excluded = set(ids_xcl)
  # random.sample needs a sequence: sampling from a set is deprecated since
  # Python 3.9 and raises TypeError since 3.11.
  candidates = [i for i in range(vocab_size) if i not in excluded]
  return random.sample(candidates, num_tkns)

def priming_probs(prime_inpt, fltr, p_pstn=-1):
  """Probabilities of selected token ids at one position of each input row.

  Model-specific boundary tokens are added around every row of `prime_inpt`:
  nothing for GPT-2, [special_tokens[0], row, 4, special_tokens[-1]] for XLNet,
  and [special_tokens[0], row, special_tokens[-1]] otherwise. The batch is then
  run through the global `model`, and the vocabulary softmax at position
  `p_pstn` is read out.

  Args:
    prime_inpt: 2-D array of token ids, one input sequence per row.
    fltr: token ids whose probabilities are returned.
    p_pstn: position to read. Only honoured for GPT-2; for the other models it
      is reset to the last position of the original row (-3 for XLNet, -2
      otherwise).

  Returns:
    numpy array of shape (len(prime_inpt), len(fltr)); entry [r, k] is the
    probability of token fltr[k] at position p_pstn of row r.
  """
  prime_inpt = np.asarray(prime_inpt)
  name = model_name.lower()
  if 'gpt2' not in name: # <|endoftext|> not required for gpt2
    n_rows = prime_inpt.shape[0]

    def _column(token_id):
      # One constant column, cast to the input's dtype (as np.insert would do).
      return np.full((n_rows, 1), token_id, dtype=prime_inpt.dtype)

    if 'xlnet' in name:
      # cls token is used differently in xlnet: id 4 (presumably XLNet's <sep>
      # -- TODO confirm) goes before the final special token.
      prime_inpt = np.hstack((_column(special_tokens[0]), prime_inpt,
                              _column(4), _column(special_tokens[-1])))
      p_pstn = -3
    else:
      prime_inpt = np.hstack((_column(special_tokens[0]), prime_inpt,
                              _column(special_tokens[-1])))
      p_pstn = -2

  # Use the GPU when available, otherwise run on CPU.
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  model.to(device)
  # Inference only: no autograd graph is needed.
  with torch.no_grad():
    logits = model(torch.tensor(prime_inpt).to(device))[0]
    # Softmax of the read-out position only; identical to softmaxing all positions.
    probs = F.softmax(logits[:, p_pstn], dim=-1).cpu().numpy()
  return probs[:, tuple(fltr)]

def _pmi_trigrams(pattern, positions):
  """Return (Position 1, Position 2, Position 3) token ids of PMI-index trigrams.

  Keeps the rows of the global `pmi_index` whose 'Pattern' contains `pattern`,
  truncates them to the first 2*num_probe rows, renumbers them from 0 and
  returns the rows selected by `positions` as a 2-D numpy array.
  """
  rows = pmi_index[pmi_index['Pattern'].str.contains(pattern)][:num_probe*2].reset_index(drop=True)
  return rows.filter(items=['Position 1', 'Position 2', 'Position 3']).to_numpy()[positions, :]


def _concat_frames(frames):
  """Stack per-cycle frames with a fresh 0..n-1 index (an empty DataFrame if none)."""
  return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame([])


# Per-cycle result frames for the AAB / ABA / ABB primes. They are concatenated once
# after the loop: DataFrame.append was removed in pandas 2.0 and re-copied the
# whole frame on every call.
aab_frames, aba_frames, abb_frames = [], [], []
frames_by_prime = (aab_frames, aba_frames, abb_frames)

for i_exp in range(experiment_cycles):
  # probabilities[k, j, r]: k = 0 unprimed (normalisation) values, k = 1/2/3 values
  # after the AAB/ABA/ABB prime; j = probed position (0 Aa, 1 Ab, 2 AAb, 3 ABa,
  # 4 ABb, 5 ABc); r = probe trigram. Entries never computed stay NaN.
  probabilities = np.full((4,6,num_probe), np.nan, dtype='float32')

  prime_AAB = prime_ABA = prime_ABB = np.array([]).astype('int64')

  if custom_primes:
    # PMI-index trigrams, each followed by one pause token. Inserting the scalar
    # works for any number of rows (the fixed-length `pauses` vector did not).
    pause_id = special_tokens[2]
    if priming_AAB: prime_AAB = np.insert(_pmi_trigrams('AAB', prime_positions),3,pause_id,axis=1)
    if priming_ABA: prime_ABA = np.insert(_pmi_trigrams('ABA', prime_positions),3,pause_id,axis=1)
    if priming_ABB: prime_ABB = np.insert(_pmi_trigrams('ABB', prime_positions),3,pause_id,axis=1)

  else: # random primes
    # Each of the num_prime_tokens**2 (A, B) pairs appears occur_prime_trigrams times, shuffled.
    prime_order = np.repeat(np.arange(num_prime_tokens**2),occur_prime_trigrams)
    np.random.shuffle(prime_order)

    ids_prime_a = token_selector(num_prime_tokens,special_tokens)
    ids_prime_b = token_selector(num_prime_tokens,np.concatenate((special_tokens, ids_prime_a), axis=0))

    # All (a, b) combinations, one pair per row.
    combs_AB = np.transpose((np.repeat(ids_prime_a, len(ids_prime_b)), np.tile(ids_prime_b, len(ids_prime_a))))

    # Rows are [t1, t2, t3, pause].
    if priming_AAB: prime_AAB = np.insert(combs_AB[prime_order][:,(0,0,1)],3,pauses,axis=1)
    if priming_ABA: prime_ABA = np.insert(combs_AB[prime_order][:,(0,1,0)],3,pauses,axis=1)
    if priming_ABB: prime_ABB = np.insert(combs_AB[prime_order][:,(0,1,1)],3,pauses,axis=1)

  used_tokens = np.unique(np.append(np.append(prime_AAB, prime_ABA),prime_ABB))

  probe_AAB = probe_ABA = probe_ABB = probe_ABC = np.array([]).astype('int64')

  if probes_eql_primes:
    # Probe with the prime trigrams themselves (pause column dropped); no ABC probes.
    if priming_AAB: probe_AAB = prime_AAB[:,:3]
    if priming_ABA: probe_ABA = prime_ABA[:,:3]
    if priming_ABB: probe_ABB = prime_ABB[:,:3]
    probing_ABCa = False
    probing_ABCb = False

  elif custom_probes:
    if real_probes:
      probing_ABCa = False
      probing_ABCb = False

      if probing_AAB: probe_AAB = _pmi_trigrams('AAB', probe_positions)
      if probing_ABA: probe_ABA = _pmi_trigrams('ABA', probe_positions)
      if probing_ABB: probe_ABB = _pmi_trigrams('ABB', probe_positions)

    else: print('Fake probe condition not implemented yet')

  else: # random probes (tokens not used in this cycle's primes)
    ids_probe_a = token_selector(num_probe_tokens,np.concatenate((special_tokens, used_tokens), axis=0))
    ids_probe_b = token_selector(num_probe_tokens,np.concatenate((special_tokens, used_tokens), axis=0))

    combs_AB = np.transpose((np.repeat(ids_probe_a, len(ids_probe_b)), np.tile(ids_probe_b, len(ids_probe_a))))

    if probing_AAB: probe_AAB = combs_AB[:,(0,0,1)]
    if probing_ABA: probe_ABA = combs_AB[:,(0,1,0)]
    if probing_ABB: probe_ABB = combs_AB[:,(0,1,1)]
    # [A, B, C]: C is the A column rolled by num_probe_tokens rows (ABCa) or the
    # B column rolled by one row (ABCb).
    if probing_ABCa: probe_ABC = np.roll(np.insert(combs_AB,0,np.roll(combs_AB[:,0],-num_probe_tokens),axis=1),-1,axis=1)
    if probing_ABCb: probe_ABC = np.roll(np.insert(combs_AB,0,np.roll(combs_AB[:,1],-1),axis=1),-1,axis=1)

  metainfo += f'\n---------\nCycle: {i_exp}\n---------\n'
  metainfo += f'Prime AAB:\n{prime_AAB}\n\nProbe AAB:\n{probe_AAB}\n\n__________\n'
  metainfo += f'Prime ABA:\n{prime_ABA}\n\nProbe ABA:\n{probe_ABA}\n\n__________\n'
  metainfo += f'Prime ABB:\n{prime_ABB}\n\nProbe ABB:\n{probe_ABB}\n\n__________\n'
  metainfo += f'Probe ABC:\n{probe_ABC}'

  mask_v = np.repeat(special_tokens[1],num_probe)

  # Normalization probs (without priming): inputs are [A, MASK] or [t1, t2, MASK].
  if probing_AAB: 
    probabilities[0,0] = np.diag(priming_probs(np.insert( probe_AAB ,1,mask_v,axis=1)[:,:2], probe_AAB[:,1] )) # Pos 2 (Aa) norm
    probabilities[0,2] = np.diag(priming_probs(np.insert( probe_AAB ,2,mask_v,axis=1)[:,:3], probe_AAB[:,2] )) # Pos 3 (AAb) norm
    
  if probing_ABA or probing_ABB or probing_ABCa or probing_ABCb:
    probabilities[0,1] = np.diag(priming_probs(np.insert( probe_ABA ,1,mask_v,axis=1)[:,:2], probe_ABA[:,1] )) # Pos 2 (Ab) norm
    if probing_ABA: probabilities[0,3] = np.diag(priming_probs(np.insert( probe_ABA ,2,mask_v,axis=1)[:,:3], probe_ABA[:,2] )) # Pos 3 (ABa) norm
    if probing_ABB: probabilities[0,4] = np.diag(priming_probs(np.insert( probe_ABB ,2,mask_v,axis=1)[:,:3], probe_ABB[:,2] )) # Pos 3 (ABb) norm
    if probing_ABCa or probing_ABCb: 
      probabilities[0,5] = np.diag(priming_probs(np.insert( probe_ABC ,2,mask_v,axis=1)[:,:3], probe_ABC[:,2] )) # Pos 3 (ABc) norm

  # Primed probs. A plain list: when some priming conditions are disabled the
  # entries have different shapes, which np.array rejects (numpy >= 1.24).
  all_primes = [prime_AAB, prime_ABA, prime_ABB]
  for prime_i, prime in enumerate(all_primes, start=1):

    if len(prime) != 0:
      # The whole prime sequence, repeated once per probe row.
      prime_temp = np.tile(prime.flatten(),(num_probe,1))
      # [prime..., MASK]; probe tokens are inserted just before the MASK below.
      primed_mask = np.roll(np.insert(prime_temp,0,mask_v,axis=1),-1,axis=1)

      if probing_AAB: 
        probabilities[prime_i,0] = np.diag(priming_probs(np.insert(primed_mask,-1, probe_AAB[:,0] ,axis=1), probe_AAB[:,1] )) # Pos 2 (Aa) prime_temp
        probabilities[prime_i,2] = np.diag(priming_probs(np.insert(np.insert(
            primed_mask,-1, probe_AAB[:,0] ,axis=1),-1, probe_AAB[:,1] ,axis=1), probe_AAB[:,2] )) # Pos 3 (AAb) prime_temp

      if probing_ABA or probing_ABB or probing_ABCa or probing_ABCb:
        probabilities[prime_i,1] = np.diag(priming_probs(np.insert(primed_mask,-1, probe_ABA[:,0] ,axis=1), probe_ABA[:,1] )) # Pos 2 (Ab) prime_temp
        if probing_ABA: probabilities[prime_i,3] = np.diag(priming_probs(np.insert(np.insert(
            primed_mask,-1, probe_ABA[:,0] ,axis=1),-1, probe_ABA[:,1] ,axis=1), probe_ABA[:,2] )) # Pos 3 (ABa) prime_temp
        if probing_ABB: probabilities[prime_i,4] = np.diag(priming_probs(np.insert(np.insert(
            primed_mask,-1, probe_ABB[:,0] ,axis=1),-1, probe_ABB[:,1] ,axis=1), probe_ABB[:,2] )) # Pos 3 (ABb) prime_temp
        if probing_ABCa or probing_ABCb:
          probabilities[prime_i,5] = np.diag(priming_probs(np.insert(np.insert(
            primed_mask,-1, probe_ABC[:,0] ,axis=1),-1, probe_ABC[:,1] ,axis=1), probe_ABC[:,2] )) # Pos 3 (ABc) prime_temp

    # Primed/unprimed ratio at position 2 shared by the ABA, ABB and ABC results.
    ratio_Ab = np.divide(probabilities[prime_i,1],probabilities[0,1])
    df_temp = pd.DataFrame(data={
        # Results...
        'P(AAB|prime)': np.divide(probabilities[prime_i,0],probabilities[0,0])*np.divide(probabilities[prime_i,2],probabilities[0,2]),
        'P(ABA|prime)': ratio_Ab*np.divide(probabilities[prime_i,3],probabilities[0,3]),
        'P(ABB|prime)': ratio_Ab*np.divide(probabilities[prime_i,4],probabilities[0,4]),
        'P(ABC|prime)': ratio_Ab*np.divide(probabilities[prime_i,5],probabilities[0,5]),
        # Token information... CONDITIONS required if not all probes are selected!!!
        'A': probe_ABA[:,0],
        'B': probe_ABA[:,1],
        'C': probe_ABC[:,2] if (probing_ABCa or probing_ABCb) else np.full((num_probe,),None),
        # Norm values...
        'P(A|A)': probabilities[0,0],
        'P(B|A)': probabilities[0,1],
        'P(B|A,A)': probabilities[0,2],
        'P(A|A,B)': probabilities[0,3],
        'P(B|A,B)': probabilities[0,4],
        'P(C|A,B)': probabilities[0,5],
        # Primed values...
        'P(A|prime,A)': probabilities[prime_i,0],
        'P(B|prime,A)': probabilities[prime_i,1],
        'P(B|prime,A,A)': probabilities[prime_i,2],
        'P(A|prime,A,B)': probabilities[prime_i,3],
        'P(B|prime,A,B)': probabilities[prime_i,4],
        'P(C|prime,A,B)': probabilities[prime_i,5],
    })

    frames_by_prime[prime_i-1].append(df_temp)

df_AAB_priming = _concat_frames(aab_frames)
df_ABA_priming = _concat_frames(aba_frames)
df_ABB_priming = _concat_frames(abb_frames)

settings = f'--------\nSETTINGS\n--------\nNLP Model: {model_name}\n'
settings += f'Priming conditions: \n\tAAB: {priming_AAB} \n\tABA: {priming_ABA} \n\tABB: {priming_ABB}\n'
settings += f'Probing conditions: \n\tAAB: {probing_AAB} \n\tABA: {probing_ABA} \n\tABB: {probing_ABB} \n\tABA\'(ABCa): {probing_ABCa} \n\tABB\'(ABCb): {probing_ABCb}\n'
settings += f'Custom primes (PPMI): {custom_primes}\nCustom probes (PPMI): {custom_probes}\n'
settings += f'Primes = probes: {probes_eql_primes}\n'
settings += f'Experiment cycles: {experiment_cycles}\n'
# Token counts and the pause token, so a run can be reproduced from its log alone.
settings += f'Prime tokens: {num_prime_tokens} (each trigram x{occur_prime_trigrams}), pause token: {prime_pause_str!r}\n'
settings += f'Probe tokens: {num_probe_tokens} ({num_probe} probe trigrams)\n'

metainfo = settings + metainfo
time_st = str(datetime.now())  # filename prefix; also used by the download cell

# Save metainfo (the context manager closes the file even if the write fails)
with open(time_st+'_metainfo.txt', 'w', encoding='utf-8') as f:
  f.write(metainfo)

for pattern, df_priming in (('AAB', df_AAB_priming), ('ABA', df_ABA_priming), ('ABB', df_ABB_priming)):
  df_priming.to_csv(f'{time_st}_{pattern}_prime.csv')

# Download Files

In [None]:
# Send the run's metainfo log and the three per-pattern result tables to the browser.
for download_name in (time_st+'_metainfo.txt', time_st+'_AAB_prime.csv',
                      time_st+'_ABA_prime.csv', time_st+'_ABB_prime.csv'):
  files.download(download_name)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>