In [None]:
!pip install transformers sentencepiece

In [None]:
from transformers import pipeline
import torch

# Prefer the first GPU, but fall back to the CPU so the notebook also runs on
# machines without CUDA instead of failing when the pipeline is created.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Zero-shot classification pipeline backed by the BART-large MNLI checkpoint.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device)

In [None]:
# Quick sanity check of the classifier on a single sentence. The bare call on the
# last line makes the notebook display its result.
candidate_labels = ["travel", "cooking", "dancing"]
sequence_to_classify = "one day I will see the world"
classifier(sequence_to_classify, candidate_labels)

In [None]:
import json
import zipfile

# Read the arXiv metadata snapshot straight out of the zip archive. Each non-blank
# line is parsed as one JSON object, and a lower-cased (title, abstract) pair is kept.
arxiv_data = []
with zipfile.ZipFile("arxiv_dataset.zip", "r") as archive:
    # The context managers close both the archive and the member stream, even on error.
    with archive.open("arxiv-metadata-oai-snapshot.json") as snapshot:
        for raw_line in snapshot:
            if not raw_line.strip():
                continue  # tolerate blank lines such as a trailing newline
            record = json.loads(raw_line)
            arxiv_data.append((record["title"].lower(), record["abstract"].lower()))

In [None]:
# Peek at the first (title, abstract) pair to confirm the snapshot was parsed.
print(arxiv_data[0])

In [None]:
import random

# Fixed seed so a fresh run draws the same subset in the same order.
random.seed(42)
# Cap the sample size at the dataset size: random.sample raises ValueError when
# asked for more items than the population contains.
# NOTE: this rebinds arxiv_data, so re-running the cell reshuffles the already
# sampled list. Cached labels are matched to examples by position, so run this
# cell once per kernel session.
arxiv_data = random.sample(arxiv_data, min(20_000, len(arxiv_data)))

In [None]:
from torch.utils.data import DataLoader

# Three independent label sets. Each batch of titles is classified against every
# set separately, so each title ends up with one winning label per set.
candidate_labels = [
    ["inquiry", "answer"],
    ["simple", "complex"],
    ["abstract", "concrete"],
]
from transformers.pipelines.pt_utils import PipelineDataset

class ZeroshotDataset:
    """Indexable view over (title, abstract) pairs that yields only the titles.

    Args:
        labels: Candidate labels associated with the dataset. Stored on the
            instance; not used by ``__getitem__`` or ``__len__``.
        data: Optional sequence of ``(title, abstract)`` pairs. When omitted,
            the notebook-level ``arxiv_data`` is used.
    """

    def __init__(self, labels, data=None):
        # Fall back to the notebook global only when no data is supplied, so the
        # class can be used (and tested) without hidden notebook state.
        self.data = arxiv_data if data is None else data
        self.labels = labels

    def __getitem__(self, idx):
        """Return the title (first element) of the pair at ``idx``."""
        return self.data[idx][0]

    def __len__(self):
        """Return the number of pairs."""
        return len(self.data)

# Iterate over the (title, abstract) pairs 256 at a time; each batch is unpacked
# below into its titles and its (unused) abstracts.
loader = DataLoader(arxiv_data, batch_size=256)
try:
    # Reuse labels from an earlier run when the cache file is present and readable.
    with open("labels.json", "r") as label_json:
        labels = json.load(label_json)
except (OSError, ValueError):
    # Relabel only when the cache is missing/unreadable (OSError) or holds invalid
    # JSON (json.JSONDecodeError is a ValueError). Any other error is not swallowed.
    labels = []
    # Inference only: no gradients are needed for any classifier call.
    with torch.no_grad():
        for i, (title, _) in enumerate(loader):
            print(f"{i}/{len(loader)} processed")
            # One flat list per batch: the winning labels for the first label set,
            # followed by those for the second set, then the third.
            ll = []
            for candidates in candidate_labels:
                result = classifier(list(title), candidates)
                # The first entry of each result's "labels" is taken as the prediction.
                ll.extend(x["labels"][0] for x in result)
            labels.append(ll)
    with open("labels.json", "w") as out:
        json.dump(labels, out)

In [None]:
import gc

# Drop the zero-shot model to free memory. Guarded so that re-running the cell
# after the name is already gone does not raise NameError.
try:
    del classifier
except NameError:
    pass
# Collect unreachable objects first so the tensors they hold are really freed,
# then release PyTorch's unused cached GPU memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
import json

# Load the per-batch label lists stored in labels.json.
with open("labels.json") as label_file:
    labels = json.load(label_file)

def split(list_a, chunk_size):
    """Lazily yield consecutive slices of ``list_a`` with ``chunk_size`` items each.

    The final slice is shorter when ``len(list_a)`` is not a multiple of
    ``chunk_size``.
    """
    offsets = range(0, len(list_a), chunk_size)
    yield from (list_a[start:start + chunk_size] for start in offsets)

# Each stored batch is a flat list: the labels for the first candidate set, then the
# second, then the third. Regroup it into one (label, label, label) tuple per title.
NUM_LABEL_SETS = 3  # must equal the number of candidate label sets used for labelling

chunked_labels = []
for batch in labels:
    if not batch:
        continue  # an empty batch would give a chunk size of 0, which range() rejects
    per_set, remainder = divmod(len(batch), NUM_LABEL_SETS)
    if remainder:
        # Without this check zip() would silently truncate and misalign the labels.
        raise ValueError(
            f"batch of {len(batch)} labels is not divisible by {NUM_LABEL_SETS}"
        )
    chunks = list(split(batch, per_set))
    # zip(*chunks) pairs the k-th label of every set, i.e. all labels of the k-th title.
    chunked_labels.extend(zip(*chunks))

labels = chunked_labels

In [None]:
import re

class TitleAbstractDataset:
    """Turns (title, abstract) pairs into (prompt, target) examples.

    ``data`` holds ``(title, abstract)`` tuples; ``descriptors`` holds, at the
    same index, an iterable of strings that is joined into the prompt prefix.
    """

    def __init__(self, data, descriptors):
        self.descriptors = descriptors
        self.data = data

    def __len__(self):
        """Number of examples, taken from ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``("summarize <descriptors>: <abstract>", title)`` for ``idx``."""
        target, body = self.data[idx]
        prefix = ", ".join(self.descriptors[idx])
        prompt = f"summarize {prefix}: {body.strip()}"
        return prompt, target


def collate_fn(batch):
    """Tokenize a batch of ``(prompt, title)`` pairs.

    Relies on the notebook-level ``tokenizer``. Prompts and titles are each padded
    to the longest item in the batch and truncated to 512 tokens.

    Returns:
        ``(input_ids, attention_mask, labels)`` where pad positions in ``labels``
        have been replaced by -100.
    """
    def _encode(texts):
        # Same tokenizer settings for the source side and the target side.
        return tokenizer(
            texts,
            padding="longest",
            max_length=512,  # XXX
            truncation=True,
            return_tensors="pt",
        )

    prompts, titles = zip(*batch)
    source = _encode(prompts)
    target_ids = _encode(titles).input_ids
    # Replace padding ids with -100 in place.
    # NOTE(review): -100 is presumably the index the model's loss ignores - confirm.
    target_ids[target_ids == tokenizer.pad_token_id] = -100
    return source.input_ids, source.attention_mask, target_ids

In [None]:
# Descriptors are looked up by example index, so there must be exactly one label
# tuple per example. Fail fast here: a mismatch would otherwise only show up as an
# IndexError in the middle of training or as silently unused labels.
assert len(labels) == len(arxiv_data), (
    f"{len(labels)} label tuples for {len(arxiv_data)} examples"
)
print(len(labels))
print(len(arxiv_data))
print(labels[0])

In [None]:
import random

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.utils.data import DataLoader

from tqdm import tqdm

# Use the first GPU when one is available; otherwise fall back to the CPU so the
# cell still runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
print(model)
model = model.to(device)

# (prompt, title) examples, shuffled and tokenized by collate_fn in batches of 10.
dataset = TitleAbstractDataset(arxiv_data, labels)
loader = DataLoader(dataset, batch_size=10, shuffle=True, collate_fn=collate_fn)
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Gradients from this many consecutive batches are accumulated before each update.
GRAD_ACCUM_STEPS = 4
# A checkpoint is written and sample titles are generated every this many batches.
PREVIEW_EVERY = 500
# Token cap for preview prompts; matches the max_length used for the training inputs.
MAX_PROMPT_TOKENS = 512

for epoch in range(10):
    pbar = tqdm(loader)
    pbar.set_description(f"epoch {epoch + 1}")
    loss_ema = None
    last_batch = len(loader) - 1
    # Start every epoch from clean gradients (also covers an interrupted earlier run).
    optim.zero_grad()
    for i, (input_ids, attention_mask, _labels) in enumerate(pbar):
        # The preview block below switches to eval mode, so re-enable training mode.
        model.train()
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        _labels = _labels.to(device)
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=_labels)
        loss = out.loss
        # Scale so the accumulated gradient is the mean over the group, not the sum.
        (loss / GRAD_ACCUM_STEPS).backward()
        # Step only after the current batch's gradient has been added. Also step on
        # the final batch so a trailing partial group is neither carried into the
        # next epoch nor dropped at the end of training.
        if (i + 1) % GRAD_ACCUM_STEPS == 0 or i == last_batch:
            optim.step()
            optim.zero_grad()

        # Exponential moving average of the unscaled loss for a smoother readout.
        if loss_ema is None:
            loss_ema = loss.item()
        else:
            loss_ema = 0.9 * loss_ema + 0.1 * loss.item()
        pbar.set_postfix_str(f"loss = {loss_ema:.3f}")

        if i % PREVIEW_EVERY == 0:
            # One checkpoint file per epoch; later saves in the same epoch overwrite it.
            torch.save(model.state_dict(), f"t5_small_ft_{epoch+1}.pth")
            model.eval()
            samples = random.sample(range(len(dataset)), k=min(len(dataset), 10))
            for sample in samples:
                prompt, _title = dataset[sample]
                # Truncate exactly like the training inputs so long abstracts do not
                # exceed the length the model is being fine-tuned on.
                prompt_ids = tokenizer(
                    prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=MAX_PROMPT_TOKENS,
                ).input_ids.to(device)
                outputs = model.generate(prompt_ids)
                print(tokenizer.decode(outputs[0], skip_special_tokens=True))
            print()