In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, T5ForConditionalGeneration
import settings as settings
from typing import Tuple, Optional
import torch
import os, logging
from pathlib import Path


class Sentinel(nn.Module):
    """Wrapper around a pretrained T5 model and tokenizer, configured from `settings`.

    `forward` tokenizes a batch of texts and label strings, runs T5, and scores
    the batch by comparing the first predicted token with the first label token.
    """

    def __init__(self) -> None:
        """Load the T5 model and tokenizer using the kwargs in `settings.sentinel`."""
        super().__init__()
        self.t5_eos = settings.t5_eos_str
        self.t5_model = T5ForConditionalGeneration.from_pretrained(**settings.sentinel['t5_model'])
        self.t5_tokenizer = AutoTokenizer.from_pretrained(**settings.sentinel['t5_tokenizer'])

    def forward(self, text: Tuple[str, ...], label: Optional[Tuple[str, ...]]):
        """Score one batch.

        Args:
            text: batch of input strings.
            label: batch of label strings, one per input. They are passed to the
                tokenizer, so they are presumably label words such as
                "positive"/"negative" — TODO confirm against the dataset.

        Returns:
            In training mode: ``(loss, accuracy)``.
            In eval mode: ``accuracy`` only.
            ``accuracy`` is a scalar tensor: the fraction of samples in this
            batch whose first predicted token equals the first label token.

        Raises:
            ValueError: if ``label`` is None (both branches need labels).
        """
        if label is None:
            raise ValueError("Sentinel.forward requires `label` to compute loss/accuracy, got None")

        # encode (text, label)
        t5_text = self.t5_tokenizer.batch_encode_plus(text, **settings.sentinel['t5_tokenizer_text'])
        t5_text = t5_text.input_ids.to(settings.device)
        t5_label = self.t5_tokenizer.batch_encode_plus(label, **settings.sentinel['t5_tokenizer_label'])
        t5_label = t5_label.input_ids.to(settings.device)

        # Normalise by the number of samples actually in this batch; the
        # configured batch size over-counts whenever a batch is smaller
        # (e.g. the final batch of an epoch).
        batch_size = max(t5_label.size(0), 1)
        target = t5_label[:, 0]

        if self.training:
            # NOTE(review): the tokenizer's attention_mask is not forwarded here,
            # so padded positions are presumably attended to — TODO confirm intent.
            t5_output = self.t5_model(input_ids=t5_text, labels=t5_label)
            # argmax over raw logits; softmax is monotonic so it is not needed.
            predicted = torch.argmax(t5_output.logits[:, 0, :], dim=-1)
            t5_accuracy = torch.sum(predicted == target) / batch_size
            return t5_output.loss, t5_accuracy

        # max_length=2 limits generation to a single new token, so scores[0]
        # holds the scores for the only generated position.
        t5_output = self.t5_model.generate(input_ids=t5_text, max_length=2, output_scores=True, return_dict_in_generate=True)
        predicted = torch.argmax(t5_output.scores[0], dim=-1)
        t5_accuracy = torch.sum(predicted == target) / batch_size
        return t5_accuracy


##############################################################################
# Experiment
##############################################################################

# Run-mode switches, both off by default.
debug = resume = False

# Experiment tag; it doubles as the logger name.
# NOTE: `id` shadows the Python builtin of the same name in this namespace.
id = "t5.small.0424.d"
logger = logging.getLogger(id)

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

##############################################################################
# Hyperparameters
##############################################################################

# Training schedule.
epochIter, batchSize = 10, 64

# Optimisation settings. The spelling `weigthDecay` (sic) is kept so the
# existing name stays available.
learnRate, weigthDecay = 5e-5, 1e-3

# Learning rate and weight decay bundled as keyword arguments.
optimizer = {
    "lr": learnRate,
    "weight_decay": weigthDecay,
}


##############################################################################
# Model
##############################################################################

# Keyword-argument bundles for the model and tokenizer; the keys match the
# ones the Sentinel class looks up under `settings.sentinel`.
sentinel = {
    # Which pretrained checkpoint to load for the model.
    "t5_model": {
        "pretrained_model_name_or_path": "t5-small",
    },
    # Which pretrained tokenizer to load, plus its defaults.
    "t5_tokenizer": {
        "pretrained_model_name_or_path": "t5-small",
        "model_max_length": 512,
        "return_tensors": "pt",
    },
    # Encoding options for input text: padded/truncated to exactly 512 tokens.
    "t5_tokenizer_text": {
        "max_length": 512,
        "truncation": True,
        "return_tensors": "pt",
        "padding": "max_length",
    },
    # Encoding options for labels: truncated to at most 2 tokens, no padding.
    "t5_tokenizer_label": {
        "max_length": 2,
        "truncation": True,
        "return_tensors": "pt",
    },
}

##############################################################################
# Dataset
##############################################################################

# T5 end-of-sequence marker and the token ids used for the two labels.
t5_eos_str = "</s>"
t5_positive_token = 1465    # tokenizer.encode("positive")
t5_negative_token = 2841    # tokenizer.encode("negative")

# Locations of the two text corpora, resolved under the user's home directory.
dataset = dict(
    web_folder=Path(Path.home(), "GPT-Sentinel/data/open-web-text-split"),
    gpt_folder=Path(Path.home(), "GPT-Sentinel/data/open-gpt-text-split"),
)

# os.cpu_count() may return None when the count cannot be determined, which
# would make min(8, None) raise TypeError; fall back to a single worker.
_num_workers = min(8, os.cpu_count() or 1)

# Per-split DataLoader keyword arguments.
dataloader = dict(
    train=dict(
        batch_size=32, shuffle=True,
        num_workers=_num_workers,
    ),
    valid=dict(
        batch_size=64, shuffle=True,
        num_workers=_num_workers,
    ),
    test=dict(
        batch_size=64, shuffle=True,
        num_workers=_num_workers,
    ),
)
# Checkpoint location for this experiment.
identifier = "t5.small.0424.d"
directory = Path(f"./storage/{identifier}")

# Build the model on the detected device instead of a hard-coded "cuda", so the
# cell also runs on CPU-only machines.
model = Sentinel().to(device)
# map_location remaps tensors saved on another device onto `device`.
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
state = torch.load(Path(directory, "state.pt"), map_location=device)
# NOTE(review): nothing here calls model.eval(), and Sentinel.forward branches
# on self.training — confirm the intended mode before running validation.
model.load_state_dict(state["model"])

<All keys matched successfully>

In [4]:
# Show the validation-split loader settings; `dataloader` here is the config dict, not a DataLoader.
dataloader.get("valid")

{'batch_size': 64, 'shuffle': True, 'num_workers': 8}

In [3]:
# NOTE(review): `iter()` is called with no argument, which raises TypeError as
# written; the ValueError recorded below implies a different call was actually
# executed. This presumably should iterate a DataLoader that yields
# (text, label) batches — TODO confirm and restore the argument.
text, label = next(iter())

ValueError: too many values to unpack (expected 2)