In [13]:
import csv
import random
import torch

from reasoner import Reasoner
from world import World

from collections import defaultdict
from tqdm import tqdm

In [14]:
# Seed Python's and PyTorch's random number generators with the same value.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Learning rate per fine-tuned model; later cells look these up by model name
# when constructing a Reasoner.
model_lr = dict([
    ('rl-property', 2.5074850299401197e-06),
    ('bl-property', 2.508483033932136e-06),
    ('axxl-property', 3.0054890219560877e-06),
])

# Checkpoint directory per model name.
# NOTE(review): not referenced by any other cell visible in this notebook.
negative_sampler = dict([
    ('axxl-property', '../../induction/checkpoints/finetuned_models/axxl-property'),
    ('bl-property', '../../induction/checkpoints/finetuned_models/bl-property'),
    ('rl-property', '../../induction/checkpoints/finetuned_models/rl-property'),
])

# Input files handed to the project-local World class.
world_paths = {
    'concept_path': "../data/concept_senses.csv",
    'feature_path': '../data/experimental splits/train_1ns.csv',
    'matrix_path': "../data/train_1ns_matrix.txt",
}
world = World(**world_paths)
# create() is defined in the project's `world` module; later cells read
# world.taxonomy, world.lexicon and world.concepts after this call.
world.create()

521it [00:00, 6553.82it/s]


In [15]:
# Build stimuli from the Osherson "three general" file.
# Each data row is (premise, conclusion, strength). The premise is a
# comma-separated list of concept names. Every stimulus is stored as
# [premise_concepts, conclusion_concepts, strength], with strength kept as the
# raw CSV string.
# NOTE(review): the next cell rebinds `stimuli` from a different file, so on a
# top-to-bottom run this cell's result is overwritten before it is used.
stimuli = []

# The conclusion set is the same for every row (all descendants of
# 'mammal.n.01' plus 'chimp'), so query the taxonomy once instead of per row.
mammal_concepts = world.taxonomy['mammal.n.01'].descendants()

# newline='' is what the csv module requires of file objects passed to
# csv.reader, so quoted fields containing line breaks are parsed correctly.
with open('../../induction/data/osherson_three_general.csv', 'r', newline='') as f:
    reader = csv.reader(f)
    # Skip the header through the reader (not the raw file) so the skip honours
    # CSV quoting; the default avoids StopIteration on an empty file.
    next(reader, None)
    for line in reader:
        # The row's own conclusion is unpacked only to enforce the 3-column
        # layout; the general conclusion set below is used instead.
        premise, conclusion, strength = line
        premise_concepts = [name.strip() for name in premise.split(",")]
        # Alternatives tried previously: the row's single conclusion
        # ([conclusion]), or the mammal descendants minus the premise concepts.
        conclusion_concepts = mammal_concepts + ['chimp']

        stimuli.append([premise_concepts, conclusion_concepts, strength])

In [16]:
# Build stimuli from the Osherson "two specific" file.
# Each data row is (premise, conclusion, strength). The premise is a
# comma-separated list of concept names and the conclusion is a single concept.
# Every stimulus is stored as [premise_concepts, conclusion_concepts, strength],
# with strength kept as the raw CSV string (converted with float() later).
# NOTE: this rebinds `stimuli`, replacing whatever the previous cell built.
stimuli = []

# newline='' is what the csv module requires of file objects passed to
# csv.reader, so quoted fields containing line breaks are parsed correctly.
with open('../../induction/data/osherson_two_specific.csv', 'r', newline='') as f:
    reader = csv.reader(f)
    # Skip the header through the reader (not the raw file) so the skip honours
    # CSV quoting; the default avoids StopIteration on an empty file.
    next(reader, None)
    for line in reader:
        premise, conclusion, strength = line
        premise_concepts = [name.strip() for name in premise.split(",")]
        # Strip the conclusion the same way as the premise concepts so that
        # stray whitespace cannot break later membership checks on the name.
        conclusion_concepts = [conclusion.strip()]
        # Alternatives tried previously: all descendants of 'mammal.n.01' plus
        # 'chimp', or those descendants minus the premise concepts.

        stimuli.append([premise_concepts, conclusion_concepts, strength])

In [17]:
# Candidate property phrases.
# NOTE(review): not referenced by the other cells visible in this notebook,
# which set `prop` directly.
PROPERTIES = ['can dax', 'can fep', 'is vorpal', 'is mimsy', 'has blickets', 'has feps', 'is a wug', 'is a tove']

# Names of the fine-tuned checkpoints to evaluate; each is also a key of
# `model_lr`.
MODELS = ['axxl-property', 'bl-property', 'rl-property']

# Use the first CUDA GPU when one is available, and fall back to the CPU
# otherwise so the notebook still runs on machines without CUDA.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [18]:
# Adapt-then-generalize experiment.
# For every model and every stimulus, a fresh Reasoner is built, adapted on the
# premise sentences (all labelled 1), and scored on the conclusion sentences.
# results[MODEL] collects [score, human strength] pairs for the correlation
# cell further down.
results = defaultdict(list)
prop = 'requires biotin for synthesizing hemoglobin'
# prop = 'can dax'


def _sentences(concepts):
    """Return one '<concept phrase> <prop>.' sentence per concept.

    Concepts found in world.concepts use the lexicon entry's `article`
    attribute as the phrase; all others fall back to 'a <concept>'.
    Reads the notebook-level `prop` at call time.
    """
    rendered = []
    for c in concepts:
        if c in world.concepts:
            rendered.append(f'{world.lexicon[c].article} {prop}.')
        else:
            rendered.append(f'a {c} {prop}.')
    return rendered


for MODEL in MODELS:
    for premise_concepts, conclusion_concepts, strength in tqdm(stimuli):
        # A new Reasoner per stimulus, so adaptation on one stimulus is not
        # carried into the next.
        reasoner = Reasoner(
            f'../../induction/checkpoints/finetuned_models/{MODEL}',
            learning_rate=model_lr[MODEL],
            lexicon=world.lexicon,
            device=DEVICE,
        )

        # Tokenize premise and conclusion sentences as padded PyTorch batches.
        adaptation = reasoner.tokenizer(
            _sentences(premise_concepts), return_tensors='pt', padding=True
        )
        generalization = reasoner.tokenizer(
            _sentences(conclusion_concepts), return_tensors='pt', padding=True
        )

        # One label of 1 per premise sentence.
        labels = torch.tensor([1] * len(premise_concepts))

        # NOTE(review): 20 and 'not' are passed positionally; presumably an
        # epoch budget and a negation/negative-sampling mode. Confirm against
        # Reasoner.adapt.
        reasoner.adapt(adaptation, labels, 20, 'not')

        # Mean of column 1 of the generalize() output over all conclusion
        # sentences, as a Python float.
        # NOTE(review): presumably column 1 is the log-probability of the
        # positive class. Confirm against Reasoner.generalize.
        column_one = reasoner.generalize(generalization)[:, 1]
        gen_logprob = column_one.mean().item()

        # Move the model to the CPU before the next Reasoner is created.
        reasoner.model.to('cpu')

        results[MODEL].append([gen_logprob, float(strength)])

100%|█████████████████████████████████████████████████████████████████████████████| 36/36 [01:57<00:00,  3.26s/it]
100%|█████████████████████████████████████████████████████████████████████████████| 36/36 [01:42<00:00,  2.86s/it]
100%|█████████████████████████████████████████████████████████████████████████████| 36/36 [01:49<00:00,  3.04s/it]


In [19]:
from scipy import stats

In [21]:
# Report, per model, the Spearman rank correlation between the model's scores
# and the strength values stored alongside them in `results`.
for MODEL in MODELS:
    pairs = results[MODEL]
    # Transpose [[score, strength], ...] into two parallel sequences.
    score, strength = zip(*pairs)
    correlation = stats.spearmanr(score, strength)
    r, p = correlation
    print(f"{MODEL}: r: {r:.4f}, p-value: {p:.4f}")

axxl-property: r: 0.5431, p-value: 0.0006
bl-property: r: 0.2486, p-value: 0.1437
rl-property: r: 0.5220, p-value: 0.0011


In [5]:
# Build a single Reasoner for interactive exploration with one checkpoint.
# NOTE(review): this cell's execution count (5) is lower than those of the
# cells above it, so it was run in a different order than it appears here.
# It needs `model_lr`, `world` and `DEVICE` to exist already.
MODEL = 'axxl-property'
reasoner_options = dict(
    learning_rate=model_lr[MODEL],
    lexicon=world.lexicon,
    device=DEVICE,
)
reasoner = Reasoner(f'../../induction/checkpoints/finetuned_models/{MODEL}', **reasoner_options)

In [6]:
# Query concepts: every descendant of 'bird.n.01' in the taxonomy that is not
# one of the positive examples.
# An earlier hand-picked alternative was:
# query = ['butterfly', 'sparrow', 'emu', 'ostrich', 'lion', 'airplane', 'helicopter', 'car']
# NOTE(review): `positive` is not defined in any cell visible in this notebook;
# on a fresh kernel this line raises NameError unless it is defined first.
bird_concepts = world.taxonomy['bird.n.01'].descendants()
query = [c for c in bird_concepts if c not in positive]

In [10]:
# Display the query concept list; positions in this list are what the indices
# in the later topk output refer to.
query

['budgie',
 'parakeet',
 'buzzard',
 'falcon',
 'hawk',
 'eagle',
 'owl',
 'canary',
 'magpie',
 'raven',
 'nightingale',
 'robin',
 'starling',
 'wren',
 'chicken',
 'cockerel',
 'turkey',
 'dove',
 'pigeon',
 'partridge',
 'peacock',
 'crane',
 'flamingo',
 'heron',
 'duck',
 'goose',
 'pelican',
 'penguin',
 'seagull',
 'swan',
 'hummingbird',
 'kingfisher',
 'woodpecker',
 'emu',
 'ostrich']

In [7]:
# Prepare the adaptation and generalization inputs for the interactive run.
# NOTE(review): `positive` and `negative` are not defined in any cell visible
# in this notebook; they must exist in the kernel before this cell runs.
# This also rebinds `prop`, which the experiment loop above reads.
prop = 'is able to fep'

# Positive examples come first, followed by the negative examples.
adaptation = reasoner.prepare_stimuli(positive + negative, prop)

# Labels follow the same order: 1 for each positive, 0 for each negative.
label_values = [1 for _ in positive] + [0 for _ in negative]
labels = torch.tensor(label_values)
# Positives-only variant used previously:
# labels = torch.tensor([1] * len(positive))

generalization = reasoner.prepare_stimuli(query, prop)

In [8]:
# Adapt the reasoner on the prepared sentences and their 1/0 labels.
# The experiment loop earlier in the notebook passes the same third argument
# (20) plus a fourth argument 'not', which is omitted here.
# NOTE(review): presumably 20 is an epoch/step budget. Confirm against
# Reasoner.adapt.
reasoner.adapt(adaptation, labels, 20)

In [17]:
# Negate column 1 of the generalize() output and take the 10 largest values,
# i.e. the 10 query entries with the lowest column-1 score. The returned
# indices are positions in `query`.
# NOTE(review): presumably column 1 is the log-probability of the positive
# class. Confirm against Reasoner.generalize.
(-1.0 * reasoner.generalize(generalization)[:, 1]).topk(10)

torch.return_types.topk(
values=tensor([0.6973, 0.6070, 0.4432, 0.4406, 0.4169, 0.4019, 0.3945, 0.3903, 0.3777,
        0.3736]),
indices=tensor([ 6, 30,  5, 34, 33,  1, 29,  2, 22,  4]))

In [22]:
# Display the reasoner's `stopping_epoch` attribute.
# NOTE(review): presumably the epoch at which the last adapt() call stopped.
# Confirm against the Reasoner class.
reasoner.stopping_epoch

1

In [9]:
# Scratch expression: the length of an empty list literal, which is always 0.
# Nothing else in the notebook uses it, so this cell is a candidate for deletion.
len([])

0

In [23]:
# Look up the concept at position 1 of `query`; index 1 is one of the indices
# in the recorded topk output.
query[1]

'parakeet'

In [None]:
# NOTE(review): this cell has no execution count and ends in a bare attribute
# access. No attribute `f` of `world` is used anywhere else in this notebook,
# so this looks like an unfinished expression. Complete or delete it before
# sharing.
world.f