In [1]:
import torch
import os

# Point HF_ENDPOINT at the hf-mirror.com mirror. This is set ahead of the
# transformers import below (presumably so the hub client picks it up -- confirm).
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

from transformers import AutoTokenizer

# Tokenizer published with the 'lvwerra/distilbert-imdb' checkpoint.
tokenizer = AutoTokenizer.from_pretrained('lvwerra/distilbert-imdb')

# Bare expression so the notebook displays the tokenizer's repr.
tokenizer

DistilBertTokenizerFast(name_or_path='lvwerra/distilbert-imdb', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [2]:
from datasets import load_dataset, concatenate_datasets

# Load the 'imdb' dataset, then merge its 'train' and 'test' splits into a
# single dataset that the loader below iterates over.
dataset = load_dataset('imdb')
dataset = concatenate_datasets([dataset['train'], dataset['test']])


def f(data):
    """Collate a list of dataset examples into one tokenized batch.

    Args:
        data: Iterable of examples, each indexable by ``'text'`` and
            ``'label'``.

    Returns:
        The tokenizer output for the batch (padded to the longest sequence,
        truncated to at most 50 tokens, PyTorch tensors) moved to the
        module-level ``device``, with an added ``'labels'`` entry holding the
        labels as a ``LongTensor`` on the same device.
    """
    texts = []
    labels = []
    for example in data:
        texts.append(example['text'])
        labels.append(example['label'])

    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=50,
        return_tensors='pt',
    )
    batch = batch.to(device)

    # Stored under 'labels' next to the tokenizer's own fields, so the whole
    # batch can be unpacked into the model call as keyword arguments.
    batch['labels'] = torch.LongTensor(labels).to(device)

    return batch


# Shuffled batches of 4 examples; an incomplete final batch is dropped and
# every batch is assembled by the collate function `f`.
loader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=f,
    batch_size=4,
    shuffle=True,
    drop_last=True,
)

# Display the number of batches together with one sample batch.
len(loader), next(iter(loader))

(12500,
 {'input_ids': tensor([[  101,  2004,  1037,  5470,  1997,  2381,  1010, 11327,  1010,  1998,
           5913,  1000, 17477,  7307,  1000,  2265,  2766,  2033,  1999,  2013,
           1996,  2131,  1011,  2175,  1012,  2009,  2038,  4100,  1011, 11519,
           5896,  2075,  1010,  1996, 12703,  2024, 10392,  1006,  2045,  2024,
          11790,  1007,  1010,  2009,  2038,  3492,  2204,  3772,  1010,   102],
         [  101, 24404, 29418, 12338,  1010,  1996,  2269,  1997,  1996,  3842,
           1999,  2010,  8795,  2005,  2634,  1005,  1055,  4071,  5998,  6439,
           2010,  2219,  2155,  1998,  2365,  1010,  2023,  3185,  2003,  2055,
           2010,  2365,  7632,  7941,  2389,  2040,  5683, 15486,  2138,  1997,
          24404, 29418, 12338,  1005,  1055,  2326,  2000,  1996,  2554,   102],
         [  101,  2004,  1037,  3627,  1010,  1045,  3046,  2000,  2424,  2004,
           2172,  1999,  3152,  2004,  1045,  4298,  2064,  2000,  5959,  2068,
           1012,

In [3]:
from transformers import AutoModelForSequenceClassification

# Load the sequence-classification checkpoint, then move it onto `device`.
model_critic = AutoModelForSequenceClassification.from_pretrained(
    'lvwerra/distilbert-imdb')
model_critic = model_critic.to(device)

# Bare expression so the notebook displays the model configuration.
model_critic.config

DistilBertConfig {
  "_name_or_path": "lvwerra/distilbert-imdb",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "id2label": {
    "0": "NEGATIVE",
    "1": "POSITIVE"
  },
  "initializer_range": 0.02,
  "label2id": {
    "NEGATIVE": 0,
    "POSITIVE": 1
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "problem_type": "single_label_classification",
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torch_dtype": "float32",
  "transformers_version": "4.43.3",
  "vocab_size": 30522
}

In [4]:
# Fine-tune `model_critic` for a single pass over `loader`, then save it to
# 'model/critic'.
optimizer = torch.optim.Adam(model_critic.parameters(), lr=1e-5)

# `from_pretrained` hands the model back in evaluation mode, which disables the
# dropout layers listed in the config. Switch to training mode explicitly so
# fine-tuning actually uses them.
model_critic.train()

for i, data in enumerate(loader):
    # `data` already contains 'labels' (added by the collate function), so the
    # forward pass returns the loss alongside the logits.
    out = model_critic(**data)
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 1000 == 0:
        # Accuracy of the current batch only -- a coarse progress signal, not
        # an evaluation metric.
        acc = (out.logits.argmax(1) == data['labels']).sum() / len(
            data['labels'])
        print(i, len(loader), out.loss.item(), acc.item())

# Return to evaluation mode (the mode the model was in before this cell) so
# any later inference with `model_critic` runs without dropout.
model_critic.eval()

model_critic.save_pretrained('model/critic')

0 12500 1.1637792587280273 0.5
1000 12500 0.29049110412597656 1.0
2000 12500 0.832800030708313 0.5
3000 12500 0.1894889622926712 1.0
4000 12500 0.2516906261444092 1.0
5000 12500 0.4431069493293762 1.0
6000 12500 0.557310938835144 0.75
7000 12500 0.6333910822868347 0.5
8000 12500 0.23886911571025848 0.75
9000 12500 0.2795445919036865 1.0
10000 12500 0.09735985845327377 1.0
11000 12500 0.05919884890317917 1.0
12000 12500 0.24701425433158875 0.75
