In [2]:
device = 'cuda'

In [46]:
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('roberta-base')
model = transformers.AutoModelForMaskedLM.from_pretrained('roberta-base').to(device)

In [None]:
import names_dataset
nd = names_dataset.NameDataset()

In [119]:
from torch.utils import data


class Sentencify:
    
    def __init__(self, template):
        self.template = template
        
    def __call__(self, name):
        return self.template.format(name=name)

class Dataset(data.Dataset):
    
    def __init__(self, n=1000, offset=0, sentencify=Sentencify('{name} climbed the hill.')):
        names = nd.get_top_names(n=n, country_alpha2='US')['US']

        self.male = names['M'][offset:]
        self.female = names['F'][offset:]

        self.data = [(sentencify(name), 0) for name in self.male]
        self.data += [(sentencify(name), 1) for name in self.female]
    
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

dataset = Dataset(n=500)

In [72]:
import torch
from torch import nn, optim
from tqdm.auto import tqdm

model.to(device)
probe = nn.Sequential(
    nn.Linear(768, 768),
    nn.ReLU(),
    nn.Linear(768, 1),
).to(device)
optimizer = optim.AdamW(probe.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

val_size = int(.1 * len(dataset))
train_size = len(dataset) - val_size
train, val = data.random_split(dataset, (train_size, val_size))
train_loader = data.DataLoader(train, batch_size=128, shuffle=True)
val_loader = data.DataLoader(val, batch_size=128)

bad, best, state_dict = 0, float('inf'), None
for epoch in range(25):
    probe.train()
    train_loss = 0
    for sentences, targets in tqdm(train_loader, desc=f'epoch {epoch}'):
        inputs = tokenizer(list(sentences), return_tensors='pt', padding='longest').to(device)
        outputs = model(**inputs, return_dict=True, output_hidden_states=True)
        reps = outputs.hidden_states[-1][:, 1]
        predictions = probe(reps)
        loss = criterion(predictions, targets.to(device)[:, None].float())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss += loss.item()
    train_loss /= len(loader)
    print('train', train_loss)

    probe.eval()
    val_loss = 0
    for sentences, targets in val_loader:
        inputs = tokenizer(list(sentences), return_tensors='pt', padding='longest').to(device)
        with torch.inference_mode():
            outputs = model(**inputs, return_dict=True, output_hidden_states=True)
            reps = outputs.hidden_states[-1][:, 1]
            predictions = probe(reps)
            loss = criterion(predictions, targets.to(device)[:, None].float())        
        val_loss += loss.item()
    val_loss /= len(val_loader)
    print('val', val_loss)

    if val_loss < best:
        bad = 0
        best = val_loss
        state_dict = probe.state_dict()
    else:
        bad += 1

    if bad >= 4:
        assert state_dict is not None
        probe.load_state_dict(state_dict)
        break

epoch 0:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.5752527788281441
val 0.5014075636863708


epoch 1:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.3967887759208679
val 0.38385698199272156


epoch 2:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.26316521130502224
val 0.2339557409286499


epoch 3:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.1910376138985157
val 0.16111941635608673


epoch 4:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.14066676329821348
val 0.13111811876296997


epoch 5:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.1368890656158328
val 0.11832942813634872


epoch 6:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.10570867778733373
val 0.10340951383113861


epoch 7:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.21578484401106834
val 0.09839450567960739


epoch 8:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.09078342205611989
val 0.09676236659288406


epoch 9:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.09178701508790255
val 0.10129508376121521


epoch 10:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.0905273063108325
val 0.0932174026966095


epoch 11:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.08239491772837937
val 0.09713620692491531


epoch 12:   0%|          | 0/8 [00:00<?, ?it/s]

train 0.08013808634132147
val 0.09423713386058807


In [199]:
@torch.inference_mode()
def predict(sentence, token=1):
    inputs = tokenizer([sentence], return_tensors='pt', padding='longest').to(device)
    print(tokenizer.convert_ids_to_tokens(inputs.input_ids.squeeze().tolist()))
    outputs = model(**inputs, return_dict=True, output_hidden_states=True)
    reps = outputs.hidden_states[-1][:, token]
    predictions = probe(reps)
    score = torch.sigmoid(predictions.squeeze()).item()
    return 'F' if  score > .5 else 'M', score

predict('The doctor is climbing the hill. They work hard.', token=2)

['<s>', 'The', 'Ġdoctor', 'Ġis', 'Ġclimbing', 'Ġthe', 'Ġhill', '.', 'ĠThey', 'Ġwork', 'Ġhard', '.', '</s>']


('M', 0.07878950983285904)

In [125]:
class FlippedDataset(data.Dataset):
    
    def __init__(self, dataset):
        self.data = []
        for i in range(len(dataset)):
            sentence, gender = dataset.data[i]
            pronoun = 'He' if gender else 'She'
            sentence = f'{sentence}. {pronoun} was tired.'
            self.data.append((sentence, 1 - gender))

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

flipped = FlippedDataset(Dataset(n=10000, offset=5000))

@torch.inference_mode()
def accuracy(dataset, probe=probe):
    loader = data.DataLoader(dataset, batch_size=128)
    correct = 0
    for sentences, targets in tqdm(loader):
        inputs = tokenizer(list(sentences), return_tensors='pt', padding='longest').to(device)
        outputs = model(**inputs, return_dict=True, output_hidden_states=True)
        reps = outputs.hidden_states[-1][:, 1]
        predictions = probe(reps).gt(.5).long()
        correct += predictions.eq(targets[:, None].to(device)).sum()
    return correct / len(val)

print(accuracy(flipped))

  0%|          | 0/71 [00:00<?, ?it/s]

tensor(81.4300, device='cuda:0')


# Potential source of sentences?

In [21]:
import pathlib

sentences_files = pathlib.Path('/raid/lingo/dez/code/gutenberg/data/parsed').glob('*.txt')

sentences = []
for sentences_file in sentences_files:
    with sentences_file.open('r') as handle:
        sentences += list(handle.read().split('\n'))        

In [23]:
[sentence for sentence in sentences if ' he ' in sentence]

['she listened to him until he flew away',
 'it was as if he were talking',
 'she had wondered if he would notice her',
 'she wished she could talk as he did',
 'she knew what he would think of her',
 'what an unhappy face he had',
 'no one—no one knew where he buried the key',
 'and then he turned round and stared at me',
 'i think he asked the robin questions',
 'she felt sure he would like to hear',
 'but he had not sent one',
 'perhaps—perhaps he has been thinking about it all afternoon',
 'medlock rather irritably when he arrived',
 'medlock opened the door he heard laughing and chattering',
 'roach heard his name he smiled quite leniently',
 'there he is',
 'colin looked as if he were resting luxuriously',
 'his face flushed scarlet and he sat bolt upright',
 'then suddenly he remembered something mary had said',
 'so he could reply like a sailor',
 'then he cheered up',
 'when the nurse came in he gave his orders',
 'he mun come back—that he mun',
 'there he found the loveliness