In [1]:
import csv
import os
import random
import torch

from reasoner import Reasoner
from world import World

from collections import defaultdict
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm

In [2]:
# Set the tokenizers parallelism flag to 'False' before any tokenizer is used.
# NOTE(review): presumably this avoids the fork-related warning when the
# DataLoader spawns worker processes — confirm against the tokenizers docs.
os.environ.update({'TOKENIZERS_PARALLELISM': 'False'})

In [3]:
# Seed both Python's and torch's RNGs with the same value for reproducibility.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Build the World from the concept, feature and matrix files, then populate it.
world = World(
    concept_path="../../induction/data/concept_senses.csv",
    feature_path='../../induction/data/post_annotation_data/post_annotation_all.csv',
    matrix_path="../../induction/data/concept_matrix.txt",
)
world.create()

521it [00:00, 4727.87it/s]


In [4]:
class PropertyJudge:
    """Thin wrapper around a sequence-classification checkpoint.

    Loads a model and its tokenizer from ``model_path``, moves the model to
    ``device`` and switches it to eval mode. ``infer`` returns hard labels,
    ``truth`` returns the probability assigned to class index 1.
    """

    def __init__(self, model_path, device='cpu'):
        """Load the model and tokenizer from ``model_path`` onto ``device``.

        Args:
            model_path: path or identifier passed to ``from_pretrained``.
            device: torch device string the model and inputs are moved to.
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.device = device

        self.model.to(self.device)
        self.model.eval()

    def tokenize(self, batch):
        """Tokenize ``batch`` with padding and return PyTorch tensors on ``self.device``."""
        return self.tokenizer(batch, padding=True, return_tensors='pt').to(self.device)

    def _log_probs(self, batch):
        """Run the model on ``batch`` and return the row-wise log-softmax of its logits.

        ``batch`` is unpacked as keyword arguments to the model (e.g. the output
        of ``tokenize``). Gradients are disabled for the forward pass.
        """
        with torch.no_grad():
            logits = self.model(**batch).logits
            # Normalise over the class dimension (dim 1).
            return torch.log_softmax(logits, dim=1)

    def infer(self, batch):
        """Return the highest-scoring class index for each row, as a list of ints."""
        return self._log_probs(batch).argmax(1).tolist()

    def truth(self, batch):
        """Return the probability of class index 1 for each row, as a list of floats.

        Note these are probabilities (the log-probability is exponentiated),
        not log-probabilities.
        NOTE(review): class index 1 is presumably the "true" label — confirm
        against the checkpoint's label mapping.
        """
        prob_true = self._log_probs(batch)[:, 1].exp()
        return prob_true.tolist()

In [5]:
# Name of the fine-tuned checkpoint and the directory it is loaded from.
MODEL = 'rl-property'
PATH = '../../induction/checkpoints/finetuned_models/' + MODEL

In [6]:
# One candidate sentence about a chimp for every entry in `world.features`.
chimp = [f"a chimp {feature}." for feature in world.features]

In [7]:
# Batch the sentences for inference; no shuffling is requested, so batches
# follow the order of `chimp`.
chimp_dl = DataLoader(
    chimp,
    batch_size=32,
    num_workers=16,
)

In [8]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA instead of failing when the model is moved.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
propjudge = PropertyJudge(PATH, DEVICE)

In [9]:
# Score every sentence batch by batch, keeping sentences and scores aligned.
# Despite its name, `logprobs` collects the values returned by
# `PropertyJudge.truth`, which exponentiates the log-probability of class 1.
logprobs, sentences = [], []
for batch in tqdm(chimp_dl):
    batch = list(batch)
    sentences += batch
    logprobs += propjudge.truth(propjudge.tokenize(batch))
    

100%|███████████████████████████████████████████████████████████████████████████| 117/117 [00:02<00:00, 45.73it/s]


In [10]:
# Positions of the sentences whose score is at least 0.5.
true_idx = torch.nonzero(torch.tensor(logprobs) >= 0.5).squeeze(1)

In [12]:
# Number of sentences at or above the 0.5 threshold (true_idx is 1-D).
true_idx.numel()

1069

In [11]:
# Scores of the sentences selected by `true_idx`.
torch.tensor(logprobs).index_select(0, true_idx)

tensor([0.6836, 0.8849, 0.5346,  ..., 0.6708, 0.6775, 0.9713])

In [13]:
# Sort all scores from highest to lowest; the result carries both the sorted
# values and their original positions.
sorted_probs = torch.sort(torch.tensor(logprobs), descending=True)

In [14]:
# Convert the sorted scores and their original positions to plain Python lists.
values = sorted_probs.values.tolist()
idx = sorted_probs.indices.tolist()

In [15]:
# Pair each sentence with its score, ordered from highest to lowest score.
inferences = list(zip((sentences[position] for position in idx), values))