In [1]:
import json
import torch

from minicons import scorer

from torch.utils.data import DataLoader
from tqdm import tqdm

from sklearn.metrics import accuracy_score, f1_score

In [2]:
def get_anli(split = "dev"):
    labels = []
    with open(f"../data/anli/{split}-labels.lst", "r") as f:
        for line in f:
            labels.append(int(line))
            
    anli = []
    with open(f"../data/anli/{split}.jsonl", "r") as f:
        for line in f:
            data = json.loads(line)
            anli.append([data['obs2'], f"{data['obs1']} {data['hyp1']}", f"{data['obs1']} {data['hyp2']}"])
            
    return anli, labels

In [3]:
anli, labels = get_anli("test")

In [41]:
gpt = scorer.IncrementalLMScorer("gpt2-medium")

Downloading: 100%|██████████| 718/718 [00:00<00:00, 188kB/s]
Downloading: 100%|██████████| 0.99M/0.99M [00:00<00:00, 1.88MB/s]
Downloading: 100%|██████████| 446k/446k [00:00<00:00, 3.23MB/s]
Downloading: 100%|██████████| 1.29M/1.29M [00:00<00:00, 5.44MB/s]
Downloading: 100%|██████████| 1.42G/1.42G [01:38<00:00, 15.4MB/s]
Using pad_token, but it is not set yet.


In [29]:
labels[5:10]

[1, 2, 1, 2, 2]

In [28]:
anli[5:10]

[['Tom and his friend had a lot of fun working together.',
  'Tom applied for a job at a call center. His buddy Charles was already there.',
  'Tom applied for a job at a call center. Tom worked with his father.'],
 ['Bob was banned from the game for a month afterwards.',
  'Bob was playing league of legends. Some of the other players insulted Bob.',
  'Bob was playing league of legends. He called someone online an asshole.'],
 ['The hotel profit grew and the business gave Tom huge bonus pay.',
  "Tom's business decided to buy a franchised hotel. Tom's business plan made money for the hotel.",
  "Tom's business decided to buy a franchised hotel. Because of no experience, Tom was demoted to another position."],
 ['ISIS killed them.',
  'Two girls have boyfriends in isis. the boyfriends were american citizens.',
  'Two girls have boyfriends in isis. The girls become involved with ISIS.'],
 ['It turned out they never survive when caught.',
  'Beth caught two fireflies in a jar. Beth found

In [42]:
anli_dl = DataLoader(anli[5:10], batch_size = 5)

In [43]:
for batch in anli_dl:
    pass

In [44]:
obs2, hyp1, hyp2 = batch

In [45]:
hyp1_scores = []
hyp2_scores = []
hyp1_scores.extend(gpt.partial_score(list(hyp1), list(obs2), reduction=lambda x: x.sum(0).item()))
hyp2_scores.extend(gpt.partial_score(list(hyp2), list(obs2), reduction=lambda x: x.sum(0).item()))

In [46]:
predicted = (torch.stack((torch.tensor(hyp1_scores), torch.tensor(hyp2_scores))).argmax(0)+1)

In [47]:
labels[:10][:5], predicted

([2, 1, 1, 1, 2], tensor([2, 1, 1, 1, 2]))

In [48]:
list(zip(hyp1_scores, hyp2_scores))

[(-33.70368194580078, -27.90334701538086),
 (-29.258975982666016, -34.2630615234375),
 (-47.15886688232422, -52.57579040527344),
 (-13.972801208496094, -14.551834106445312),
 (-38.940528869628906, -38.845726013183594)]

In [40]:
gpt.token_score("Tom applied for a job at a call center. His buddy Charles was already there. Tom and his friend had a lot of fun working together.")

[[('Tom', 0.0),
  ('applied', -11.225250244140625),
  ('for', -1.5445404052734375),
  ('a', -1.4278068542480469),
  ('job', -1.6258468627929688),
  ('at', -1.205352783203125),
  ('a', -2.2387351989746094),
  ('call', -7.6460418701171875),
  ('center', -0.33109283447265625),
  ('.', -2.320404052734375),
  ('His', -3.696258544921875),
  ('buddy', -8.141716003417969),
  ('Charles', -7.778144836425781),
  ('was', -2.7195816040039062),
  ('already', -4.171699523925781),
  ('there', -1.3886184692382812),
  ('.', -1.276641845703125),
  ('Tom', -8.863441467285156),
  ('and', -3.7482376098632812),
  ('his', -2.6439132690429688),
  ('friend', -1.9194564819335938),
  ('had', -3.78790283203125),
  ('a', -2.5724945068359375),
  ('lot', -3.3725738525390625),
  ('of', -0.4783773422241211),
  ('fun', -1.9626235961914062),
  ('working', -2.7781906127929688),
  ('together', -2.467620849609375),
  ('.', -0.829345703125)]]

In [36]:
x = [('Tom', -2.4024810791015625),
  ('and', -3.6292877197265625),
  ('his', -0.414031982421875),
  ('friend', -5.403038024902344),
  ('had', -3.56390380859375),
  ('a', -1.6639251708984375),
  ('lot', -3.6679534912109375),
  ('of', -0.35161805152893066),
  ('fun', -1.5521011352539062),
  ('working', -2.3483734130859375),
  ('together', -1.6277923583984375),
  ('.', -0.5834808349609375)]

words, lps = list(zip(*x))

In [39]:
torch.tensor(lps).sum()

tensor(-27.2080)

In [27]:
list(zip(hyp2, obs2, hyp2_scores))

[('Tom applied for a job at a call center. Tom worked with his father.',
  'Tom and his friend had a lot of fun working together.',
  -27.20798683166504),
 ('Bob was playing league of legends. He called someone online an asshole.',
  'Bob was banned from the game for a month afterwards.',
  -35.95820999145508),
 ("Tom's business decided to buy a franchised hotel. Because of no experience, Tom was demoted to another position.",
  'The hotel profit grew and the business gave Tom huge bonus pay.',
  -58.56022644042969),
 ('Two girls have boyfriends in isis. The girls become involved with ISIS.',
  'ISIS killed them.',
  -13.746368408203125),
 ('Beth caught two fireflies in a jar. beth found the fireflies dead.',
  'It turned out they never survive when caught.',
  -39.0667724609375)]

In [32]:
hyp1_scores = []
hyp2_scores = []
for batch in tqdm(anli_dl):
    obs2, hyp1, hyp2 = batch
    hyp1_scores.extend(gpt.partial_score(list(hyp1), list(obs2)))
    hyp2_scores.extend(gpt.partial_score(list(hyp2), list(obs2)))
    

100%|██████████| 153/153 [12:50<00:00,  5.03s/it]


In [52]:
predicted = (torch.stack((torch.tensor(hyp1_scores), torch.tensor(hyp2_scores))).argmax(0)+1)

In [51]:
torch.stack((torch.tensor(hyp1_scores), torch.tensor(hyp2_scores))).argmax(0)+1

tensor([2, 1, 1, 1, 2])

In [56]:
(torch.tensor(labels) == predicted).float().mean()

tensor(0.5273)

In [48]:
len(labels) == len(predicted)

True

In [51]:
accuracy_score(labels, predicted)

0.5272965021248774

In [39]:
torch.tensor(hyp1_scores)

tensor([-2.9011, -3.5036, -3.3421,  ..., -4.0807, -3.2114, -2.7891])

In [27]:
gpt.partial_score(list(hyp1), list(obs2)), gpt.partial_score(list(hyp2), list(obs2))

([-2.789137840270996], [-3.0277607440948486])